(improvement)(chat) Split chat into three modules: server, api, and core. (#594)

This commit is contained in:
lexluo09
2024-01-04 16:56:49 +08:00
committed by GitHub
parent 0858c13365
commit 023e84c420
337 changed files with 2407 additions and 2715 deletions

View File

@@ -1,23 +0,0 @@
package com.tencent.supersonic.chat.application.parser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.sql.rule.TimeRangeParser;
import org.junit.jupiter.api.Test;
class TimeRangeParserTest {
@Test
void parse() {
TimeRangeParser timeRangeParser = new TimeRangeParser();
QueryReq queryRequest = new QueryReq();
ChatContext chatCtx = new ChatContext();
queryRequest.setQueryText("supersonic最近30天访问次数");
timeRangeParser.parse(new QueryContext(queryRequest), chatCtx);
}
}

View File

@@ -1,38 +0,0 @@
package com.tencent.supersonic.chat.application.parser.aggregate;
import static org.junit.Assert.assertEquals;
import com.tencent.supersonic.chat.parser.sql.rule.AggregateTypeParser;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import org.junit.jupiter.api.Test;
class AggregateTypeParserTest {
@Test
void getAggregateParser() {
AggregateTypeParser aggregateParser = new AggregateTypeParser();
AggregateTypeEnum aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
assertEquals(aggregateType, AggregateTypeEnum.MAX);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品pv");
assertEquals(aggregateType, AggregateTypeEnum.COUNT);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品uv");
assertEquals(aggregateType, AggregateTypeEnum.DISTINCT);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
assertEquals(aggregateType, AggregateTypeEnum.MAX);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最小值");
assertEquals(aggregateType, AggregateTypeEnum.MIN);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数平均值");
assertEquals(aggregateType, AggregateTypeEnum.AVG);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数topN");
assertEquals(aggregateType, AggregateTypeEnum.TOPN);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数汇总");
assertEquals(aggregateType, AggregateTypeEnum.SUM);
}
}

View File

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.application.search;
import org.junit.jupiter.api.Test;
class SearchServiceImplTest {
@Test
void search() {
}
@Test
void filerMetricsByModel() {
}
}

View File

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class LoadRemoveServiceTest {
@Test
void edit() {
int compute = EditDistance.compute("", "在你的身边");
Assert.assertEquals(compute, 4);
}
}

View File

@@ -1,55 +0,0 @@
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import java.util.HashMap;
import java.util.Map;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class LLMResponseServiceTest {
@Test
void deduplicationSqlWeight() {
String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp = new LLMResp();
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp.setSqlRespMap(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp2 = new LLMResp();
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp2.setSqlRespMap(sqlWeight2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp3 = new LLMResp();
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp3.setSqlRespMap(sqlWeight3);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
Assert.assertEquals(deduplicationSqlResp.size(), 2);
}
}

View File

@@ -1,64 +0,0 @@
package com.tencent.supersonic.chat.parser.llm.s2sql;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class LLMS2SQLParserTest {
@Test
void setFilter() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
SchemaValueMap value1 = new SchemaValueMap();
value1.setBizName("杰伦");
value1.setTechName("周杰伦");
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
schemaValueMaps.add(value1);
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.model(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
SchemaElement schemaElement2 = SchemaElement.builder()
.bizName("publish_time")
.name("发布时间")
.model(2L)
.build();
dimensions.add(schemaElement2);
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
List<SchemaElement> metrics = new ArrayList<>();
SchemaElement metric = SchemaElement.builder()
.bizName("play_count")
.name("播放量")
.model(2L)
.build();
metrics.add(metric);
when(mockSemanticSchema.getMetrics()).thenReturn(metrics);
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
}
}

View File

@@ -1,163 +0,0 @@
package com.tencent.supersonic.chat.processor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.processor.parse.MetricCheckProcessor;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
class MetricCheckProcessorTest {
@Test
void testProcessCorrectSql_necessaryDimension_groupBy() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_necessaryDimension_where() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 "
+ "WHERE 部门 = 'HR' GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionNotDrillDown_groupBy() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionNotDrillDown_where() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT sum(访问次数) FROM 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionDrillDown() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 访问次数 from 超音数";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "select 访问次数 from 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
/**
* 访问次数 drill down dimension is 用户名 and 部门
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
*/
private SemanticSchema mockModelSchema() {
ModelSchema modelSchema = new ModelSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC,
Lists.newArrayList(RelatedSchemaElement.builder().dimensionId(2L).isNecessary(false).build(),
RelatedSchemaElement.builder().dimensionId(1L).isNecessary(false).build())),
mockElement(2L, "访问用户数", SchemaElementType.METRIC,
Lists.newArrayList(RelatedSchemaElement.builder().dimensionId(2L).isNecessary(true).build()))
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private SemanticSchema mockModelSchemaNoDimensionSetting() {
ModelSchema modelSchema = new ModelSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList()),
mockElement(2L, "访问用户数", SchemaElementType.METRIC, Lists.newArrayList())
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private Set<SchemaElement> mockDimensions() {
return Sets.newHashSet(
mockElement(1L, "用户名", SchemaElementType.DIMENSION, Lists.newArrayList()),
mockElement(2L, "部门", SchemaElementType.DIMENSION, Lists.newArrayList()),
mockElement(3L, "页面", SchemaElementType.DIMENSION, Lists.newArrayList())
);
}
private SchemaElement mockElement(Long id, String name, SchemaElementType type,
List<RelatedSchemaElement> relateSchemaElements) {
return SchemaElement.builder().id(id).name(name).type(type)
.relatedSchemaElements(relateSchemaElements).build();
}
private SemanticParseInfo mockParseInfo(String correctSql) {
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
return semanticParseInfo;
}
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.test;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.ComponentScan;
@SpringBootApplication(scanBasePackages = {"com.tencent.supersonic.chat"})
@ComponentScan("com.tencent.supersonic.chat")
@MapperScan("com.tencent.supersonic.chat")
public class ChatBizLauncher {
public static void main(String[] args) {
SpringApplication.run(ChatBizLauncher.class, args);
}
}

View File

@@ -1,151 +0,0 @@
package com.tencent.supersonic.chat.test.context;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.DefaultMetricInfo;
import com.tencent.supersonic.chat.config.EntityInternalDetail;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.chat.persistence.repository.impl.ChatContextRepositoryImpl;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.response.DimensionResp;
import com.tencent.supersonic.headless.api.response.MetricResp;
import com.tencent.supersonic.headless.api.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelService;
import org.mockito.Mockito;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;
@Configuration
public class MockBeansConfiguration {
public static void getOrCreateContextMock(ChatService chatService) {
ChatContext context = new ChatContext();
context.setChatId(1);
when(chatService.getOrCreateContext(1)).thenReturn(context);
}
public static void buildHttpSemanticServiceImpl(List<DimSchemaResp> dimensionDescs,
List<MetricSchemaResp> metricDescs) {
DefaultMetric defaultMetricDesc = new DefaultMetric();
defaultMetricDesc.setUnit(3);
defaultMetricDesc.setPeriod(Constants.DAY);
List<DimSchemaResp> dimensionDescs1 = new ArrayList<>();
DimSchemaResp dimensionDesc = new DimSchemaResp();
dimensionDesc.setId(162L);
dimensionDescs1.add(dimensionDesc);
DimSchemaResp dimensionDesc2 = new DimSchemaResp();
dimensionDesc2.setId(163L);
dimensionDesc2.setBizName("song_name");
dimensionDesc2.setName("歌曲名");
EntityInternalDetail entityInternalDetailDesc = new EntityInternalDetail();
entityInternalDetailDesc.setDimensionList(new ArrayList<>(Arrays.asList(dimensionDesc2)));
MetricSchemaResp metricDesc = new MetricSchemaResp();
metricDesc.setId(877L);
metricDesc.setBizName("js_play_cnt");
metricDesc.setName("结算播放量");
entityInternalDetailDesc.setMetricList(new ArrayList<>(Arrays.asList(metricDesc)));
ModelSchemaResp modelSchemaDesc = new ModelSchemaResp();
modelSchemaDesc.setDimensions(dimensionDescs);
modelSchemaDesc.setMetrics(metricDescs);
}
public static void getModelExtendMock(ConfigServiceImpl configService) {
DefaultMetricInfo defaultMetricInfo = new DefaultMetricInfo();
defaultMetricInfo.setUnit(3);
defaultMetricInfo.setPeriod(Constants.DAY);
List<DefaultMetricInfo> defaultMetricInfos = new ArrayList<>();
defaultMetricInfos.add(defaultMetricInfo);
ChatConfigResp chaConfigDesc = new ChatConfigResp();
when(configService.fetchConfigByModelId(anyLong())).thenReturn(chaConfigDesc);
}
public static void dimensionDescBuild(DimensionService dimensionService, List<DimensionResp> dimensionDescs) {
when(dimensionService.getDimensions(any(DimensionFilter.class))).thenReturn(dimensionDescs);
}
public static void metricDescBuild(MetricService metricService, List<MetricResp> metricDescs) {
when(metricService.getMetrics(any(MetaFilter.class))).thenReturn(metricDescs);
}
public static DimSchemaResp getDimensionDesc(Long id, String bizName, String name) {
DimSchemaResp dimensionDesc = new DimSchemaResp();
dimensionDesc.setId(id);
dimensionDesc.setName(name);
dimensionDesc.setBizName(bizName);
return dimensionDesc;
}
public static MetricSchemaResp getMetricDesc(Long id, String bizName, String name) {
MetricSchemaResp dimensionDesc = new MetricSchemaResp();
dimensionDesc.setId(id);
dimensionDesc.setName(name);
dimensionDesc.setBizName(bizName);
return dimensionDesc;
}
@Bean
public ChatContextRepositoryImpl getChatContextRepository() {
return Mockito.mock(ChatContextRepositoryImpl.class);
}
@Bean
public QueryService getQueryService() {
return Mockito.mock(QueryService.class);
}
@Bean
public DimensionService getDimensionService() {
return Mockito.mock(DimensionService.class);
}
@Bean
public MetricService getMetricService() {
return Mockito.mock(MetricService.class);
}
//queryDimensionDescs
@Bean
public ModelService getModelService() {
return Mockito.mock(ModelService.class);
}
@Bean
public ChatContextMapper getChatContextMapper() {
return Mockito.mock(ChatContextMapper.class);
}
@Bean
public ConfigServiceImpl getModelExtendService() {
return Mockito.mock(ConfigServiceImpl.class);
}
@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}
}

View File

@@ -1,118 +0,0 @@
package com.tencent.supersonic.chat.test.context;
import com.google.gson.Gson;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.DateConf;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import lombok.Data;
public class SemanticParseObjectHelper {
public static SemanticParseInfo copy(SemanticParseInfo semanticParseInfo) {
Gson g = new Gson();
return g.fromJson(g.toJson(semanticParseInfo), SemanticParseInfo.class);
}
public static SemanticParseInfo getSemanticParseInfo(String json) {
Gson gson = new Gson();
SemanticParseJson semanticParseJson = gson.fromJson(json, SemanticParseJson.class);
if (semanticParseJson != null) {
return getSemanticParseInfo(semanticParseJson);
}
return null;
}
private static SemanticParseInfo getSemanticParseInfo(SemanticParseJson semanticParseJson) {
Long model = semanticParseJson.getModel();
Set<SchemaElement> dimensionList = new LinkedHashSet();
Set<SchemaElement> metricList = new LinkedHashSet();
Set<QueryFilter> chatFilters = new LinkedHashSet();
if (semanticParseJson.getFilter() != null && semanticParseJson.getFilter().size() > 0) {
for (List<String> filter : semanticParseJson.getFilter()) {
chatFilters.add(getChatFilter(filter));
}
}
for (String dim : semanticParseJson.getDimensions()) {
dimensionList.add(getDimension(dim, model));
}
for (String metric : semanticParseJson.getMetrics()) {
metricList.add(getMetric(metric, model));
}
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setDimensionFilters(chatFilters);
semanticParseInfo.setAggType(semanticParseJson.getAggregateType());
semanticParseInfo.setQueryMode(semanticParseJson.getQueryMode());
semanticParseInfo.setMetrics(metricList);
semanticParseInfo.setDimensions(dimensionList);
DateConf dateInfo = getDateInfoAgo(semanticParseJson.getDay());
semanticParseInfo.setDateInfo(dateInfo);
return semanticParseInfo;
}
private static DateConf getDateInfoAgo(int dayAgo) {
if (dayAgo > 0) {
DateConf dateInfo = new DateConf();
dateInfo.setUnit(dayAgo);
dateInfo.setDateMode(DateConf.DateMode.RECENT);
return dateInfo;
}
return null;
}
private static QueryFilter getChatFilter(List<String> filters) {
if (filters.size() > 1) {
QueryFilter chatFilter = new QueryFilter();
chatFilter.setBizName(filters.get(1));
chatFilter.setOperator(FilterOperatorEnum.getSqlOperator(filters.get(2)));
if (filters.size() > 4) {
List<String> valuse = new ArrayList<>();
valuse.addAll(filters.subList(3, filters.size()));
chatFilter.setValue(valuse);
} else {
chatFilter.setValue(filters.get(3));
}
return chatFilter;
}
return null;
}
private static SchemaElement getMetric(String bizName, Long modelId) {
SchemaElement metric = new SchemaElement();
metric.setBizName(bizName);
return metric;
}
private static SchemaElement getDimension(String bizName, Long modelId) {
SchemaElement dimension = new SchemaElement();
dimension.setBizName(bizName);
return dimension;
}
@Data
public static class SemanticParseJson {
private Long model;
private String queryMode;
private AggregateTypeEnum aggregateType;
private Integer day;
private List<String> dimensions;
private List<String> metrics;
private List<List<String>> filter;
}
}

View File

@@ -1,78 +0,0 @@
package com.tencent.supersonic.chat.utils;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateModeUtils;
import com.tencent.supersonic.common.util.SqlFilterUtils;
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
/**
* QueryReqBuilderTest
*/
class QueryReqBuilderTest {
@Test
void buildS2SQLReq() {
init();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setModelId(1L);
queryStructReq.setQueryType(QueryType.METRIC);
queryStructReq.setModelName("内容库");
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.UNKNOWN);
aggregator.setColumn("pv");
queryStructReq.setAggregators(Arrays.asList(aggregator));
queryStructReq.setGroups(Arrays.asList("department"));
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateMode.LIST);
dateConf.setDateList(Arrays.asList("2023-08-01"));
queryStructReq.setDateInfo(dateConf);
List<Order> orders = new ArrayList<>();
Order order = new Order();
order.setColumn("uv");
orders.add(order);
queryStructReq.setOrders(orders);
QueryS2SQLReq queryS2SQLReq = queryStructReq.convert(queryStructReq);
Assert.assertEquals(
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2SQLReq.getSql());
queryStructReq.setQueryType(QueryType.TAG);
queryS2SQLReq = queryStructReq.convert(queryStructReq);
Assert.assertEquals(
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "ORDER BY uv LIMIT 2000",
queryS2SQLReq.getSql());
}
private void init() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SqlFilterUtils sqlFilterUtils = new SqlFilterUtils();
mockContextUtils.when(() -> ContextUtils.getBean(SqlFilterUtils.class)).thenReturn(sqlFilterUtils);
DateModeUtils dateModeUtils = new DateModeUtils();
mockContextUtils.when(() -> ContextUtils.getBean(DateModeUtils.class)).thenReturn(dateModeUtils);
dateModeUtils.setSysDateCol("sys_imp_date");
dateModeUtils.setSysDateWeekCol("sys_imp_week");
dateModeUtils.setSysDateMonthCol("sys_imp_month");
}
}