mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-20 13:44:19 +08:00
(improvement)(chat) Split chat into three modules: server, api, and core. (#594)
This commit is contained in:
@@ -1,23 +0,0 @@
|
||||
package com.tencent.supersonic.chat.application.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.sql.rule.TimeRangeParser;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
||||
class TimeRangeParserTest {
|
||||
|
||||
@Test
|
||||
void parse() {
|
||||
TimeRangeParser timeRangeParser = new TimeRangeParser();
|
||||
|
||||
QueryReq queryRequest = new QueryReq();
|
||||
ChatContext chatCtx = new ChatContext();
|
||||
|
||||
queryRequest.setQueryText("supersonic最近30天访问次数");
|
||||
timeRangeParser.parse(new QueryContext(queryRequest), chatCtx);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package com.tencent.supersonic.chat.application.parser.aggregate;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.sql.rule.AggregateTypeParser;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class AggregateTypeParserTest {
|
||||
|
||||
@Test
|
||||
void getAggregateParser() {
|
||||
AggregateTypeParser aggregateParser = new AggregateTypeParser();
|
||||
AggregateTypeEnum aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.MAX);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品pv");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.COUNT);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品uv");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.DISTINCT);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.MAX);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最小值");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.MIN);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数平均值");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.AVG);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数topN");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.TOPN);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数汇总");
|
||||
assertEquals(aggregateType, AggregateTypeEnum.SUM);
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package com.tencent.supersonic.chat.application.search;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
||||
class SearchServiceImplTest {
|
||||
|
||||
@Test
|
||||
void search() {
|
||||
}
|
||||
|
||||
@Test
|
||||
void filerMetricsByModel() {
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class LoadRemoveServiceTest {
|
||||
|
||||
@Test
|
||||
void edit() {
|
||||
int compute = EditDistance.compute("在", "在你的身边");
|
||||
Assert.assertEquals(compute, 4);
|
||||
}
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class LLMResponseServiceTest {
|
||||
|
||||
@Test
|
||||
void deduplicationSqlWeight() {
|
||||
String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
|
||||
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
llmResp.setSqlRespMap(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp2 = new LLMResp();
|
||||
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
|
||||
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
llmResp2.setSqlRespMap(sqlWeight2);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp3 = new LLMResp();
|
||||
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
|
||||
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
llmResp3.setSqlRespMap(sqlWeight3);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
class LLMS2SQLParserTest {
|
||||
|
||||
@Test
|
||||
void setFilter() {
|
||||
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
|
||||
|
||||
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
|
||||
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
|
||||
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
|
||||
SchemaValueMap value1 = new SchemaValueMap();
|
||||
value1.setBizName("杰伦");
|
||||
value1.setTechName("周杰伦");
|
||||
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
|
||||
schemaValueMaps.add(value1);
|
||||
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.model(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
|
||||
SchemaElement schemaElement2 = SchemaElement.builder()
|
||||
.bizName("publish_time")
|
||||
.name("发布时间")
|
||||
.model(2L)
|
||||
.build();
|
||||
dimensions.add(schemaElement2);
|
||||
|
||||
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
|
||||
|
||||
List<SchemaElement> metrics = new ArrayList<>();
|
||||
SchemaElement metric = SchemaElement.builder()
|
||||
.bizName("play_count")
|
||||
.name("播放量")
|
||||
.model(2L)
|
||||
.build();
|
||||
metrics.add(metric);
|
||||
|
||||
when(mockSemanticSchema.getMetrics()).thenReturn(metrics);
|
||||
|
||||
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
|
||||
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
|
||||
}
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package com.tencent.supersonic.chat.processor;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.processor.parse.MetricCheckProcessor;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
class MetricCheckProcessorTest {
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_necessaryDimension_groupBy() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_necessaryDimension_where() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 "
|
||||
+ "WHERE 部门 = 'HR' GROUP BY 用户名";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_dimensionNotDrillDown_groupBy() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_dimensionNotDrillDown_where() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT sum(访问次数) FROM 超音数";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_dimensionDrillDown() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
|
||||
String expectedProcessedSql = "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 访问次数 from 超音数";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "select 访问次数 from 超音数";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
|
||||
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
|
||||
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
/**
|
||||
* 访问次数 drill down dimension is 用户名 and 部门
|
||||
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
|
||||
*/
|
||||
private SemanticSchema mockModelSchema() {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
Set<SchemaElement> metrics = Sets.newHashSet(
|
||||
mockElement(1L, "访问次数", SchemaElementType.METRIC,
|
||||
Lists.newArrayList(RelatedSchemaElement.builder().dimensionId(2L).isNecessary(false).build(),
|
||||
RelatedSchemaElement.builder().dimensionId(1L).isNecessary(false).build())),
|
||||
mockElement(2L, "访问用户数", SchemaElementType.METRIC,
|
||||
Lists.newArrayList(RelatedSchemaElement.builder().dimensionId(2L).isNecessary(true).build()))
|
||||
);
|
||||
modelSchema.setMetrics(metrics);
|
||||
modelSchema.setDimensions(mockDimensions());
|
||||
return new SemanticSchema(Lists.newArrayList(modelSchema));
|
||||
}
|
||||
|
||||
private SemanticSchema mockModelSchemaNoDimensionSetting() {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
Set<SchemaElement> metrics = Sets.newHashSet(
|
||||
mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList()),
|
||||
mockElement(2L, "访问用户数", SchemaElementType.METRIC, Lists.newArrayList())
|
||||
);
|
||||
modelSchema.setMetrics(metrics);
|
||||
modelSchema.setDimensions(mockDimensions());
|
||||
return new SemanticSchema(Lists.newArrayList(modelSchema));
|
||||
}
|
||||
|
||||
private Set<SchemaElement> mockDimensions() {
|
||||
return Sets.newHashSet(
|
||||
mockElement(1L, "用户名", SchemaElementType.DIMENSION, Lists.newArrayList()),
|
||||
mockElement(2L, "部门", SchemaElementType.DIMENSION, Lists.newArrayList()),
|
||||
mockElement(3L, "页面", SchemaElementType.DIMENSION, Lists.newArrayList())
|
||||
);
|
||||
}
|
||||
|
||||
private SchemaElement mockElement(Long id, String name, SchemaElementType type,
|
||||
List<RelatedSchemaElement> relateSchemaElements) {
|
||||
return SchemaElement.builder().id(id).name(name).type(type)
|
||||
.relatedSchemaElements(relateSchemaElements).build();
|
||||
}
|
||||
|
||||
private SemanticParseInfo mockParseInfo(String correctSql) {
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package com.tencent.supersonic.chat.test;
|
||||
|
||||
import org.mybatis.spring.annotation.MapperScan;
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.context.annotation.ComponentScan;
|
||||
|
||||
|
||||
@SpringBootApplication(scanBasePackages = {"com.tencent.supersonic.chat"})
|
||||
@ComponentScan("com.tencent.supersonic.chat")
|
||||
@MapperScan("com.tencent.supersonic.chat")
|
||||
public class ChatBizLauncher {
|
||||
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(ChatBizLauncher.class, args);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package com.tencent.supersonic.chat.test.context;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.config.DefaultMetric;
|
||||
import com.tencent.supersonic.chat.config.DefaultMetricInfo;
|
||||
import com.tencent.supersonic.chat.config.EntityInternalDetail;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.impl.ChatContextRepositoryImpl;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@Configuration
|
||||
public class MockBeansConfiguration {
|
||||
|
||||
public static void getOrCreateContextMock(ChatService chatService) {
|
||||
ChatContext context = new ChatContext();
|
||||
context.setChatId(1);
|
||||
when(chatService.getOrCreateContext(1)).thenReturn(context);
|
||||
}
|
||||
|
||||
public static void buildHttpSemanticServiceImpl(List<DimSchemaResp> dimensionDescs,
|
||||
List<MetricSchemaResp> metricDescs) {
|
||||
DefaultMetric defaultMetricDesc = new DefaultMetric();
|
||||
defaultMetricDesc.setUnit(3);
|
||||
defaultMetricDesc.setPeriod(Constants.DAY);
|
||||
List<DimSchemaResp> dimensionDescs1 = new ArrayList<>();
|
||||
DimSchemaResp dimensionDesc = new DimSchemaResp();
|
||||
dimensionDesc.setId(162L);
|
||||
dimensionDescs1.add(dimensionDesc);
|
||||
|
||||
DimSchemaResp dimensionDesc2 = new DimSchemaResp();
|
||||
dimensionDesc2.setId(163L);
|
||||
dimensionDesc2.setBizName("song_name");
|
||||
dimensionDesc2.setName("歌曲名");
|
||||
|
||||
EntityInternalDetail entityInternalDetailDesc = new EntityInternalDetail();
|
||||
entityInternalDetailDesc.setDimensionList(new ArrayList<>(Arrays.asList(dimensionDesc2)));
|
||||
MetricSchemaResp metricDesc = new MetricSchemaResp();
|
||||
metricDesc.setId(877L);
|
||||
metricDesc.setBizName("js_play_cnt");
|
||||
metricDesc.setName("结算播放量");
|
||||
entityInternalDetailDesc.setMetricList(new ArrayList<>(Arrays.asList(metricDesc)));
|
||||
|
||||
ModelSchemaResp modelSchemaDesc = new ModelSchemaResp();
|
||||
modelSchemaDesc.setDimensions(dimensionDescs);
|
||||
modelSchemaDesc.setMetrics(metricDescs);
|
||||
|
||||
}
|
||||
|
||||
public static void getModelExtendMock(ConfigServiceImpl configService) {
|
||||
DefaultMetricInfo defaultMetricInfo = new DefaultMetricInfo();
|
||||
defaultMetricInfo.setUnit(3);
|
||||
defaultMetricInfo.setPeriod(Constants.DAY);
|
||||
List<DefaultMetricInfo> defaultMetricInfos = new ArrayList<>();
|
||||
defaultMetricInfos.add(defaultMetricInfo);
|
||||
|
||||
ChatConfigResp chaConfigDesc = new ChatConfigResp();
|
||||
when(configService.fetchConfigByModelId(anyLong())).thenReturn(chaConfigDesc);
|
||||
}
|
||||
|
||||
public static void dimensionDescBuild(DimensionService dimensionService, List<DimensionResp> dimensionDescs) {
|
||||
when(dimensionService.getDimensions(any(DimensionFilter.class))).thenReturn(dimensionDescs);
|
||||
}
|
||||
|
||||
public static void metricDescBuild(MetricService metricService, List<MetricResp> metricDescs) {
|
||||
when(metricService.getMetrics(any(MetaFilter.class))).thenReturn(metricDescs);
|
||||
}
|
||||
|
||||
public static DimSchemaResp getDimensionDesc(Long id, String bizName, String name) {
|
||||
DimSchemaResp dimensionDesc = new DimSchemaResp();
|
||||
dimensionDesc.setId(id);
|
||||
dimensionDesc.setName(name);
|
||||
dimensionDesc.setBizName(bizName);
|
||||
return dimensionDesc;
|
||||
}
|
||||
|
||||
public static MetricSchemaResp getMetricDesc(Long id, String bizName, String name) {
|
||||
MetricSchemaResp dimensionDesc = new MetricSchemaResp();
|
||||
dimensionDesc.setId(id);
|
||||
dimensionDesc.setName(name);
|
||||
dimensionDesc.setBizName(bizName);
|
||||
return dimensionDesc;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ChatContextRepositoryImpl getChatContextRepository() {
|
||||
return Mockito.mock(ChatContextRepositoryImpl.class);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public QueryService getQueryService() {
|
||||
return Mockito.mock(QueryService.class);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public DimensionService getDimensionService() {
|
||||
return Mockito.mock(DimensionService.class);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MetricService getMetricService() {
|
||||
return Mockito.mock(MetricService.class);
|
||||
}
|
||||
|
||||
//queryDimensionDescs
|
||||
|
||||
@Bean
|
||||
public ModelService getModelService() {
|
||||
return Mockito.mock(ModelService.class);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ChatContextMapper getChatContextMapper() {
|
||||
return Mockito.mock(ChatContextMapper.class);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ConfigServiceImpl getModelExtendService() {
|
||||
return Mockito.mock(ConfigServiceImpl.class);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public RestTemplate restTemplate() {
|
||||
return new RestTemplate();
|
||||
}
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
package com.tencent.supersonic.chat.test.context;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import lombok.Data;
|
||||
|
||||
public class SemanticParseObjectHelper {
|
||||
|
||||
public static SemanticParseInfo copy(SemanticParseInfo semanticParseInfo) {
|
||||
Gson g = new Gson();
|
||||
return g.fromJson(g.toJson(semanticParseInfo), SemanticParseInfo.class);
|
||||
}
|
||||
|
||||
public static SemanticParseInfo getSemanticParseInfo(String json) {
|
||||
Gson gson = new Gson();
|
||||
SemanticParseJson semanticParseJson = gson.fromJson(json, SemanticParseJson.class);
|
||||
if (semanticParseJson != null) {
|
||||
return getSemanticParseInfo(semanticParseJson);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static SemanticParseInfo getSemanticParseInfo(SemanticParseJson semanticParseJson) {
|
||||
Long model = semanticParseJson.getModel();
|
||||
Set<SchemaElement> dimensionList = new LinkedHashSet();
|
||||
Set<SchemaElement> metricList = new LinkedHashSet();
|
||||
Set<QueryFilter> chatFilters = new LinkedHashSet();
|
||||
|
||||
if (semanticParseJson.getFilter() != null && semanticParseJson.getFilter().size() > 0) {
|
||||
for (List<String> filter : semanticParseJson.getFilter()) {
|
||||
chatFilters.add(getChatFilter(filter));
|
||||
}
|
||||
}
|
||||
|
||||
for (String dim : semanticParseJson.getDimensions()) {
|
||||
dimensionList.add(getDimension(dim, model));
|
||||
}
|
||||
for (String metric : semanticParseJson.getMetrics()) {
|
||||
metricList.add(getMetric(metric, model));
|
||||
}
|
||||
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
|
||||
semanticParseInfo.setDimensionFilters(chatFilters);
|
||||
semanticParseInfo.setAggType(semanticParseJson.getAggregateType());
|
||||
semanticParseInfo.setQueryMode(semanticParseJson.getQueryMode());
|
||||
semanticParseInfo.setMetrics(metricList);
|
||||
semanticParseInfo.setDimensions(dimensionList);
|
||||
|
||||
DateConf dateInfo = getDateInfoAgo(semanticParseJson.getDay());
|
||||
semanticParseInfo.setDateInfo(dateInfo);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
private static DateConf getDateInfoAgo(int dayAgo) {
|
||||
if (dayAgo > 0) {
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setUnit(dayAgo);
|
||||
dateInfo.setDateMode(DateConf.DateMode.RECENT);
|
||||
return dateInfo;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static QueryFilter getChatFilter(List<String> filters) {
|
||||
if (filters.size() > 1) {
|
||||
QueryFilter chatFilter = new QueryFilter();
|
||||
|
||||
chatFilter.setBizName(filters.get(1));
|
||||
chatFilter.setOperator(FilterOperatorEnum.getSqlOperator(filters.get(2)));
|
||||
if (filters.size() > 4) {
|
||||
List<String> valuse = new ArrayList<>();
|
||||
valuse.addAll(filters.subList(3, filters.size()));
|
||||
chatFilter.setValue(valuse);
|
||||
} else {
|
||||
chatFilter.setValue(filters.get(3));
|
||||
}
|
||||
|
||||
return chatFilter;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static SchemaElement getMetric(String bizName, Long modelId) {
|
||||
SchemaElement metric = new SchemaElement();
|
||||
metric.setBizName(bizName);
|
||||
return metric;
|
||||
}
|
||||
|
||||
private static SchemaElement getDimension(String bizName, Long modelId) {
|
||||
SchemaElement dimension = new SchemaElement();
|
||||
dimension.setBizName(bizName);
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class SemanticParseJson {
|
||||
|
||||
private Long model;
|
||||
private String queryMode;
|
||||
private AggregateTypeEnum aggregateType;
|
||||
private Integer day;
|
||||
private List<String> dimensions;
|
||||
private List<String> metrics;
|
||||
private List<List<String>> filter;
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package com.tencent.supersonic.chat.utils;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateModeUtils;
|
||||
import com.tencent.supersonic.common.util.SqlFilterUtils;
|
||||
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
/**
|
||||
* QueryReqBuilderTest
|
||||
*/
|
||||
class QueryReqBuilderTest {
|
||||
|
||||
@Test
|
||||
void buildS2SQLReq() {
|
||||
init();
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
queryStructReq.setModelId(1L);
|
||||
queryStructReq.setQueryType(QueryType.METRIC);
|
||||
queryStructReq.setModelName("内容库");
|
||||
|
||||
Aggregator aggregator = new Aggregator();
|
||||
aggregator.setFunc(AggOperatorEnum.UNKNOWN);
|
||||
aggregator.setColumn("pv");
|
||||
queryStructReq.setAggregators(Arrays.asList(aggregator));
|
||||
|
||||
queryStructReq.setGroups(Arrays.asList("department"));
|
||||
|
||||
DateConf dateConf = new DateConf();
|
||||
dateConf.setDateMode(DateMode.LIST);
|
||||
dateConf.setDateList(Arrays.asList("2023-08-01"));
|
||||
queryStructReq.setDateInfo(dateConf);
|
||||
|
||||
List<Order> orders = new ArrayList<>();
|
||||
Order order = new Order();
|
||||
order.setColumn("uv");
|
||||
orders.add(order);
|
||||
queryStructReq.setOrders(orders);
|
||||
|
||||
QueryS2SQLReq queryS2SQLReq = queryStructReq.convert(queryStructReq);
|
||||
Assert.assertEquals(
|
||||
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2SQLReq.getSql());
|
||||
|
||||
queryStructReq.setQueryType(QueryType.TAG);
|
||||
queryS2SQLReq = queryStructReq.convert(queryStructReq);
|
||||
Assert.assertEquals(
|
||||
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
+ "ORDER BY uv LIMIT 2000",
|
||||
queryS2SQLReq.getSql());
|
||||
|
||||
}
|
||||
|
||||
private void init() {
|
||||
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
|
||||
SqlFilterUtils sqlFilterUtils = new SqlFilterUtils();
|
||||
mockContextUtils.when(() -> ContextUtils.getBean(SqlFilterUtils.class)).thenReturn(sqlFilterUtils);
|
||||
DateModeUtils dateModeUtils = new DateModeUtils();
|
||||
mockContextUtils.when(() -> ContextUtils.getBean(DateModeUtils.class)).thenReturn(dateModeUtils);
|
||||
dateModeUtils.setSysDateCol("sys_imp_date");
|
||||
dateModeUtils.setSysDateWeekCol("sys_imp_week");
|
||||
dateModeUtils.setSysDateMonthCol("sys_imp_month");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user