mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-20 05:26:57 +08:00
(improvement) Move out the datasource and merge the datasource with the model, and adapt the chat module (#423)
Co-authored-by: jolunoluo <jolunoluo@tencent.com>
This commit is contained in:
@@ -6,8 +6,11 @@ import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@@ -17,7 +20,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_necessaryDimension_groupBy() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
@@ -26,7 +30,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_necessaryDimension_where() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 "
|
||||
+ "WHERE 部门 = 'HR' GROUP BY 用户名";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
@@ -36,7 +41,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_dimensionNotDrillDown_groupBy() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
@@ -45,7 +51,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_dimensionNotDrillDown_where() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
@@ -54,7 +61,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT sum(访问次数) FROM 超音数";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
@@ -63,7 +71,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_dimensionDrillDown() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
@@ -72,7 +81,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
@@ -82,7 +92,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 访问次数 from 超音数";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "select 访问次数 from 超音数";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
@@ -92,7 +103,8 @@ class MetricCheckPostProcessorTest {
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
@@ -102,7 +114,7 @@ class MetricCheckPostProcessorTest {
|
||||
* 访问次数 drill down dimension is 用户名 and 部门
|
||||
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
|
||||
*/
|
||||
private ModelSchema mockModelSchema() {
|
||||
private SemanticSchema mockModelSchema() {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
Set<SchemaElement> metrics = Sets.newHashSet(
|
||||
mockElement(1L, "访问次数", SchemaElementType.METRIC,
|
||||
@@ -113,10 +125,10 @@ class MetricCheckPostProcessorTest {
|
||||
);
|
||||
modelSchema.setMetrics(metrics);
|
||||
modelSchema.setDimensions(mockDimensions());
|
||||
return modelSchema;
|
||||
return new SemanticSchema(Lists.newArrayList(modelSchema));
|
||||
}
|
||||
|
||||
private ModelSchema mockModelSchemaNoDimensionSetting() {
|
||||
private SemanticSchema mockModelSchemaNoDimensionSetting() {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
Set<SchemaElement> metrics = Sets.newHashSet(
|
||||
mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList()),
|
||||
@@ -124,7 +136,7 @@ class MetricCheckPostProcessorTest {
|
||||
);
|
||||
modelSchema.setMetrics(metrics);
|
||||
modelSchema.setDimensions(mockDimensions());
|
||||
return modelSchema;
|
||||
return new SemanticSchema(Lists.newArrayList(modelSchema));
|
||||
}
|
||||
|
||||
private Set<SchemaElement> mockDimensions() {
|
||||
@@ -141,4 +153,10 @@ class MetricCheckPostProcessorTest {
|
||||
.relateSchemaElements(relateSchemaElements).build();
|
||||
}
|
||||
|
||||
private SemanticParseInfo mockParseInfo(String correctSql) {
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.test.ChatBizLauncher;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.query.service.QueryService;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.slf4j.Logger;
|
||||
|
||||
@@ -1,32 +1,24 @@
|
||||
package com.tencent.supersonic.chat.test.context;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.config.DefaultMetric;
|
||||
import com.tencent.supersonic.chat.config.DefaultMetricInfo;
|
||||
import com.tencent.supersonic.chat.config.EntityInternalDetail;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.impl.ChatContextRepositoryImpl;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.DimensionFilter;
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
|
||||
import org.mockito.Mockito;
|
||||
@@ -34,6 +26,14 @@ import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@Configuration
|
||||
public class MockBeansConfiguration {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user