(improvement) Move out the datasource and merge the datasource with the model, and adapt the chat module (#423)

Co-authored-by: jolunoluo <jolunoluo@tencent.com>
This commit is contained in:
jipeli
2023-11-27 11:05:24 +08:00
committed by GitHub
parent 0534053ff9
commit 27bb1b322e
190 changed files with 3900 additions and 10561 deletions

View File

@@ -6,8 +6,11 @@ import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
@@ -17,7 +20,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_necessaryDimension_groupBy() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -26,7 +30,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_necessaryDimension_where() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 "
+ "WHERE 部门 = 'HR' GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -36,7 +41,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionNotDrillDown_groupBy() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -45,7 +51,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionNotDrillDown_where() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -54,7 +61,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT sum(访问次数) FROM 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -63,7 +71,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionDrillDown() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -72,7 +81,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_noDrillDownDimensionSetting() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -82,7 +92,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 访问次数 from 超音数";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "select 访问次数 from 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -92,7 +103,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -102,7 +114,7 @@ class MetricCheckPostProcessorTest {
* 访问次数 drill down dimension is 用户名 and 部门
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
*/
private ModelSchema mockModelSchema() {
private SemanticSchema mockModelSchema() {
ModelSchema modelSchema = new ModelSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC,
@@ -113,10 +125,10 @@ class MetricCheckPostProcessorTest {
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return modelSchema;
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private ModelSchema mockModelSchemaNoDimensionSetting() {
private SemanticSchema mockModelSchemaNoDimensionSetting() {
ModelSchema modelSchema = new ModelSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList()),
@@ -124,7 +136,7 @@ class MetricCheckPostProcessorTest {
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return modelSchema;
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private Set<SchemaElement> mockDimensions() {
@@ -141,4 +153,10 @@ class MetricCheckPostProcessorTest {
.relateSchemaElements(relateSchemaElements).build();
}
private SemanticParseInfo mockParseInfo(String correctSql) {
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
return semanticParseInfo;
}
}

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter;
import com.tencent.supersonic.chat.test.ChatBizLauncher;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.query.service.QueryService;
import org.junit.runner.RunWith;
import org.slf4j.Logger;

View File

@@ -1,32 +1,24 @@
package com.tencent.supersonic.chat.test.context;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.DefaultMetricInfo;
import com.tencent.supersonic.chat.config.EntityInternalDetail;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.chat.persistence.repository.impl.ChatContextRepositoryImpl;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.model.domain.pojo.DimensionFilter;
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
import org.mockito.Mockito;
@@ -34,6 +26,14 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;
@Configuration
public class MockBeansConfiguration {