Mirror of https://github.com/tencentmusic/supersonic.git, synced 2025-12-13 21:17:08 +00:00
(improvement)(Chat) Move chat-core to headless (#805)
Co-authored-by: jolunoluo
@@ -0,0 +1,145 @@
package com.tencent.supersonic.chat.core.corrector;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.core.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.testng.Assert.assertEquals;

class SchemaCorrectorTest {

    private String json = "{\n"
            + "  \"dataSetId\": 1,\n"
            + "  \"llmReq\": {\n"
            + "    \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
            + "    \"filterCondition\": {\n"
            + "      \"tableName\": null\n"
            + "    },\n"
            + "    \"schema\": {\n"
            + "      \"domainName\": \"歌曲\",\n"
            + "      \"dataSetName\": \"歌曲\",\n"
            + "      \"fieldNameList\": [\n"
            + "        \"商务组\",\n"
            + "        \"歌曲名\",\n"
            + "        \"播放量\",\n"
            + "        \"播放份额\",\n"
            + "        \"数据日期\"\n"
            + "      ]\n"
            + "    },\n"
            + "    \"linking\": [\n"
            + "\n"
            + "    ],\n"
            + "    \"currentDate\": \"2024-02-24\",\n"
            + "    \"priorExts\": \"播放份额是小数; \",\n"
            + "    \"sqlGenerationMode\": \"2_pass_auto_cot\"\n"
            + "  },\n"
            + "  \"request\": null,\n"
            + "  \"commonAgentTool\": {\n"
            + "    \"id\": \"y3LqVSRL\",\n"
            + "    \"name\": \"大模型语义解析\",\n"
            + "    \"type\": \"NL2SQL_LLM\",\n"
            + "    \"dataSetIds\": [\n"
            + "      1\n"
            + "    ]\n"
            + "  },\n"
            + "  \"linkingValues\": [\n"
            + "\n"
            + "  ]\n"
            + "}";

    @Test
    void doCorrect() throws JsonProcessingException {
        Long dataSetId = 1L;
        QueryContext queryContext = buildQueryContext(dataSetId);
        ObjectMapper objectMapper = new ObjectMapper();
        ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);

        String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
                + "and 商务组 = 'xxx' order by 播放量 desc limit 10";
        SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
        SqlInfo sqlInfo = new SqlInfo();
        sqlInfo.setS2SQL(sql);
        sqlInfo.setCorrectS2SQL(sql);
        semanticParseInfo.setSqlInfo(sqlInfo);

        SchemaElement schemaElement = new SchemaElement();
        schemaElement.setDataSet(dataSetId);
        semanticParseInfo.setDataSet(schemaElement);

        semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);

        // The parse result carries no linking values, so the "商务组 = 'xxx'" filter is expected to be removed.
        SchemaCorrector schemaCorrector = new SchemaCorrector();
        schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);

        assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
                + "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());

        // Register 商务组 = 'xxx' as a linking value; this time the filter is expected to be kept.
        parseResult = objectMapper.readValue(json, ParseResult.class);

        List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
        LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
        elementValue.setFieldName("商务组");
        elementValue.setFieldValue("xxx");
        linkingValues.add(elementValue);
        parseResult.setLinkingValues(linkingValues);
        semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);

        semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
        semanticParseInfo.getSqlInfo().setS2SQL(sql);
        schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
        assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
                + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
    }

    private QueryContext buildQueryContext(Long dataSetId) {
        QueryContext queryContext = new QueryContext();
        List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
        DataSetSchema dataSetSchema = new DataSetSchema();
        QueryConfig queryConfig = new QueryConfig();
        dataSetSchema.setQueryConfig(queryConfig);
        SchemaElement schemaElement = new SchemaElement();
        schemaElement.setDataSet(dataSetId);
        dataSetSchema.setDataSet(schemaElement);
        Set<SchemaElement> dimensions = new HashSet<>();
        SchemaElement element1 = new SchemaElement();
        element1.setDataSet(1L);
        element1.setName("歌曲名");
        dimensions.add(element1);

        SchemaElement element2 = new SchemaElement();
        element2.setDataSet(1L);
        element2.setName("商务组");
        dimensions.add(element2);

        SchemaElement element3 = new SchemaElement();
        element3.setDataSet(1L);
        element3.setName("发行日期");
        dimensions.add(element3);

        dataSetSchema.setDimensions(dimensions);
        dataSetSchemaList.add(dataSetSchema);

        SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
        queryContext.setSemanticSchema(semanticSchema);
        return queryContext;
    }
}
@@ -0,0 +1,101 @@
package com.tencent.supersonic.chat.core.corrector;

import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.corrector.TimeCorrector;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import org.junit.jupiter.api.Test;
import org.testng.Assert;

class TimeCorrectorTest {

    @Test
    void testDoCorrect() {
        // TimeCorrector is expected to append a matching lower bound (>=) when the date filter
        // only sets an upper bound (< or <=), and to leave the SQL unchanged in every other case.
        TimeCorrector corrector = new TimeCorrector();
        QueryContext queryContext = new QueryContext();
        SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
        SqlInfo sqlInfo = new SqlInfo();

        // 1. 数据日期 <=
        String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
                + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
        sqlInfo.setCorrectS2SQL(sql);
        semanticParseInfo.setSqlInfo(sqlInfo);
        corrector.doCorrect(queryContext, semanticParseInfo);

        Assert.assertEquals(
                "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
                        + "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
                sqlInfo.getCorrectS2SQL());

        // 2. 数据日期 <
        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
                + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
        sqlInfo.setCorrectS2SQL(sql);
        corrector.doCorrect(queryContext, semanticParseInfo);

        Assert.assertEquals(
                "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
                        + "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
                sqlInfo.getCorrectS2SQL());

        // 3. 数据日期 >=
        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
                + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
        sqlInfo.setCorrectS2SQL(sql);
        corrector.doCorrect(queryContext, semanticParseInfo);

        Assert.assertEquals(
                "SELECT 维度1, SUM(播放量) FROM 数据库 "
                        + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
                sqlInfo.getCorrectS2SQL());

        // 4. 数据日期 >
        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
                + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
        sqlInfo.setCorrectS2SQL(sql);
        corrector.doCorrect(queryContext, semanticParseInfo);

        Assert.assertEquals(
                "SELECT 维度1, SUM(播放量) FROM 数据库 "
                        + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
                sqlInfo.getCorrectS2SQL());

        // 5. no 数据日期 filter
        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
                + "WHERE 歌手名 = '张三' GROUP BY 维度1";
        sqlInfo.setCorrectS2SQL(sql);
        corrector.doCorrect(queryContext, semanticParseInfo);

        Assert.assertEquals(
                "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
                sqlInfo.getCorrectS2SQL());

        // 6. 数据日期_月 <=
        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
                + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
        sqlInfo.setCorrectS2SQL(sql);
        corrector.doCorrect(queryContext, semanticParseInfo);

        Assert.assertEquals(
                "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
                        + "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
                sqlInfo.getCorrectS2SQL());

        // 7. 数据日期_月 >
        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
                + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
        sqlInfo.setCorrectS2SQL(sql);
        corrector.doCorrect(queryContext, semanticParseInfo);

        Assert.assertEquals(
                "SELECT 维度1, SUM(播放量) FROM 数据库 "
                        + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
                sqlInfo.getCorrectS2SQL());

        // 8. no WHERE clause
        sql = "SELECT COUNT(1) FROM 数据库";
        sqlInfo.setCorrectS2SQL(sql);
        corrector.doCorrect(queryContext, semanticParseInfo);
        Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
    }
}
@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.core.mapper;

import com.hankcs.hanlp.algorithm.EditDistance;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

class LoadRemoveServiceTest {

    @Test
    void edit() {
        int compute = EditDistance.compute("在", "在你的身边");
        Assertions.assertEquals(4, compute);
    }
}
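For context on the expected value above: HanLP's EditDistance behaves here like plain Levenshtein distance, and turning "在" into "在你的身边" takes four character insertions. Below is a minimal reference sketch of that computation, for illustration only; the helper name is hypothetical and this is not the HanLP implementation.

// Minimal Levenshtein distance sketch; the test itself relies on HanLP's EditDistance.
static int levenshtein(String a, String b) {
    int[] prev = new int[b.length() + 1];
    int[] curr = new int[b.length() + 1];
    for (int j = 0; j <= b.length(); j++) {
        prev[j] = j;                      // turning "" into b[0..j) costs j insertions
    }
    for (int i = 1; i <= a.length(); i++) {
        curr[0] = i;                      // turning a[0..i) into "" costs i deletions
        for (int j = 1; j <= b.length(); j++) {
            int substitution = prev[j - 1] + (a.charAt(i - 1) == b.charAt(j - 1) ? 0 : 1);
            curr[j] = Math.min(substitution, Math.min(prev[j] + 1, curr[j - 1] + 1));
        }
        int[] tmp = prev;
        prev = curr;
        curr = tmp;
    }
    return prev[b.length()];              // levenshtein("在", "在你的身边") == 4
}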
@@ -0,0 +1,37 @@
package com.tencent.supersonic.chat.core.parser.aggregate;

import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.core.chat.parser.rule.AggregateTypeParser;
import org.junit.jupiter.api.Test;
import org.testng.Assert;

class AggregateTypeParserTest {

    @Test
    void getAggregateParser() {
        AggregateTypeParser aggregateParser = new AggregateTypeParser();
        AggregateTypeEnum aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.MAX);

        aggregateType = aggregateParser.resolveAggregateType("supsersonic产品pv");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.COUNT);

        aggregateType = aggregateParser.resolveAggregateType("supsersonic产品uv");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.DISTINCT);

        aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.MAX);

        aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最小值");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.MIN);

        aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数平均值");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.AVG);

        aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数topN");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.TOPN);

        aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数汇总");
        Assert.assertEquals(aggregateType, AggregateTypeEnum.SUM);
    }
}
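The cases above read as a keyword-to-aggregate mapping: 最大值→MAX, 最小值→MIN, 平均值→AVG, 汇总→SUM, topN→TOPN, pv→COUNT, uv→DISTINCT. The sketch below is a naive lookup that would reproduce these expectations; the class and its rules are assumptions for illustration, not how AggregateTypeParser is actually implemented.

// Illustration only (assumption): a naive keyword lookup consistent with the assertions above.
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.LinkedHashMap;
import java.util.Map;

class NaiveAggregateResolver {

    private static final Map<String, AggregateTypeEnum> KEYWORDS = new LinkedHashMap<>();

    static {
        KEYWORDS.put("最大值", AggregateTypeEnum.MAX);
        KEYWORDS.put("最小值", AggregateTypeEnum.MIN);
        KEYWORDS.put("平均值", AggregateTypeEnum.AVG);
        KEYWORDS.put("汇总", AggregateTypeEnum.SUM);
        KEYWORDS.put("topN", AggregateTypeEnum.TOPN);
        KEYWORDS.put("pv", AggregateTypeEnum.COUNT);
        KEYWORDS.put("uv", AggregateTypeEnum.DISTINCT);
    }

    static AggregateTypeEnum resolve(String queryText) {
        for (Map.Entry<String, AggregateTypeEnum> entry : KEYWORDS.entrySet()) {
            if (queryText.contains(entry.getKey())) {
                return entry.getValue();  // first matching keyword wins
            }
        }
        return null;                      // no aggregate keyword found
    }
}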
@@ -0,0 +1,56 @@
package com.tencent.supersonic.chat.core.s2sql;

import com.tencent.supersonic.headless.core.chat.parser.llm.LLMResponseService;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
import org.junit.jupiter.api.Test;
import org.testng.Assert;

import java.util.HashMap;
import java.util.Map;

class LLMResponseServiceTest {

    @Test
    void deduplicationSqlWeight() {
        // The two SQLs differ only in select-column order and predicate order,
        // so deduplication should collapse them into a single entry.
        String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
        String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";

        LLMResp llmResp = new LLMResp();
        Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
        sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
        sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());

        llmResp.setSqlRespMap(sqlWeight);
        LLMResponseService llmResponseService = new LLMResponseService();
        Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);

        Assert.assertEquals(deduplicationSqlResp.size(), 1);

        // The same pair on a fresh response object: still one entry after deduplication.
        sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
        sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";

        LLMResp llmResp2 = new LLMResp();
        Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
        sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
        sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());

        llmResp2.setSqlRespMap(sqlWeight2);
        deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);

        Assert.assertEquals(deduplicationSqlResp.size(), 1);

        // Here sql1 selects an extra column e, so the two SQLs are no longer duplicates.
        sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
        sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";

        LLMResp llmResp3 = new LLMResp();
        Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
        sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
        sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
        llmResp3.setSqlRespMap(sqlWeight3);
        deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);

        Assert.assertEquals(deduplicationSqlResp.size(), 2);
    }
}
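One simple way to see why the first two pairs collapse and the third does not: SQLs that differ only in column or predicate order contain the same multiset of tokens. The sketch below uses that token-fingerprint idea for illustration; the SqlFingerprint class is hypothetical, and getDeduplicationSqlResp may canonicalize SQL quite differently (e.g. through a SQL parser).

// Illustration only (assumption): a token-multiset fingerprint consistent with the assertions above.
import java.util.Arrays;
import java.util.Locale;

final class SqlFingerprint {
    static String of(String sql) {
        String[] tokens = sql.toLowerCase(Locale.ROOT).split("\\W+"); // split on non-word characters
        Arrays.sort(tokens);                                          // make the fingerprint order-insensitive
        return String.join(" ", tokens);
    }
}
// SqlFingerprint.of(sql1).equals(SqlFingerprint.of(sql2)) holds for the first two pairs
// (same tokens, different order) and fails once sql1 selects the extra column e.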
@@ -0,0 +1,56 @@
package com.tencent.supersonic.chat.core.s2sql;

import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.mockito.Mockito.when;

class LLMSqlParserTest {

    @Test
    void setFilter() {
        SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);

        List<SchemaElement> dimensions = new ArrayList<>();
        List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
        SchemaValueMap value1 = new SchemaValueMap();
        value1.setBizName("杰伦");
        value1.setTechName("周杰伦");
        value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
        schemaValueMaps.add(value1);

        SchemaElement schemaElement = SchemaElement.builder()
                .bizName("singer_name")
                .name("歌手名")
                .dataSet(2L)
                .schemaValueMaps(schemaValueMaps)
                .build();
        dimensions.add(schemaElement);

        SchemaElement schemaElement2 = SchemaElement.builder()
                .bizName("publish_time")
                .name("发布时间")
                .dataSet(2L)
                .build();
        dimensions.add(schemaElement2);

        when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);

        List<SchemaElement> metrics = new ArrayList<>();
        SchemaElement metric = SchemaElement.builder()
                .bizName("play_count")
                .name("播放量")
                .dataSet(2L)
                .build();
        metrics.add(metric);

        when(mockSemanticSchema.getMetrics()).thenReturn(metrics);
    }
}
@@ -0,0 +1,123 @@
package com.tencent.supersonic.chat.core.utils;

import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

class S2SqlDateHelperTest {

    @Test
    void getReferenceDate() {
        Long dataSetId = 1L;
        QueryContext queryContext = buildQueryContext(dataSetId);

        // Without a time default config, the reference date falls back to today.
        String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null);
        Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));

        referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
        Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));

        DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
        QueryConfig queryConfig = dataSetSchema.getQueryConfig();
        TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
        timeDefaultConfig.setTimeMode(TimeMode.LAST);
        timeDefaultConfig.setPeriod(Constants.DAY);
        timeDefaultConfig.setUnit(20);
        queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);

        // TimeMode.LAST with unit=n moves the reference date n days back.
        referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
        Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));

        timeDefaultConfig.setUnit(1);
        referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
        Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));

        // A negative unit disables the default and yields no reference date.
        timeDefaultConfig.setUnit(-1);
        referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
        Assert.assertNull(referenceDate);
    }

    @Test
    void getStartEndDate() {
        Long dataSetId = 1L;
        QueryContext queryContext = buildQueryContext(dataSetId);

        Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.TAG);
        Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
        Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));

        startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
        Assert.assertNull(startEndDate.getLeft());
        Assert.assertNull(startEndDate.getRight());

        DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
        QueryConfig queryConfig = dataSetSchema.getQueryConfig();
        TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
        timeDefaultConfig.setTimeMode(TimeMode.LAST);
        timeDefaultConfig.setPeriod(Constants.DAY);
        timeDefaultConfig.setUnit(20);
        queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
        queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);

        // TimeMode.LAST with unit=20: both ends of the window land 20 days back.
        startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
        Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
        Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));

        startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
        Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
        Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));

        // TimeMode.RECENT with unit=2: the window is [2 days back, 1 day back].
        timeDefaultConfig.setUnit(2);
        timeDefaultConfig.setTimeMode(TimeMode.RECENT);
        startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
        Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
        Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));

        startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
        Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
        Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));

        // A negative unit disables the default date range.
        timeDefaultConfig.setUnit(-1);
        startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
        Assert.assertNull(startEndDate.getLeft());
        Assert.assertNull(startEndDate.getRight());

        timeDefaultConfig.setTimeMode(TimeMode.LAST);
        timeDefaultConfig.setPeriod(Constants.DAY);
        timeDefaultConfig.setUnit(5);
        startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
        Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
        Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
    }

    private QueryContext buildQueryContext(Long dataSetId) {
        QueryContext queryContext = new QueryContext();
        List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
        DataSetSchema dataSetSchema = new DataSetSchema();
        QueryConfig queryConfig = new QueryConfig();
        dataSetSchema.setQueryConfig(queryConfig);
        SchemaElement schemaElement = new SchemaElement();
        schemaElement.setDataSet(dataSetId);
        dataSetSchema.setDataSet(schemaElement);
        dataSetSchemaList.add(dataSetSchema);

        SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
        queryContext.setSemanticSchema(semanticSchema);
        return queryContext;
    }
}
17
headless/core/src/test/resources/application.yaml
Normal file
@@ -0,0 +1,17 @@
mybatis:
  mapper-locations: classpath:mapper/*.xml

spring:
  h2:
    console:
      path: /h2-console/semantic
      # enable the web console
      enabled: true
  datasource:
    driver-class-name: org.h2.Driver
    url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false
    username: root
    password: semantic
    schema: classpath:db/chat-schema-h2.sql
    data: classpath:db/chat-data-h2.sql
2
headless/core/src/test/resources/data/README.url
Normal file
@@ -0,0 +1,2 @@
[InternetShortcut]
URL=https://github.com/hankcs/HanLP/
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,6 @@
tom _1_2 5
alice _1_2 6
lucy _1_2 4
dean _1_2 2
john _1_2 8
jack _1_2 8
1
headless/core/src/test/resources/data/version.txt
Normal file
@@ -0,0 +1 @@
1.7.5
1
headless/core/src/test/resources/db/chat-data-h2.sql
Normal file
@@ -0,0 +1 @@
insert into chat_context (chat_id, modified_at, `user`, `query_text`, `semantic_parse`, ext_data) VALUES(1, '2023-05-24 00:00:00', 'admin', '超音数访问次数', '', 'admin');
59
headless/core/src/test/resources/db/chat-schema-h2.sql
Normal file
@@ -0,0 +1,59 @@
CREATE TABLE `chat_context`
(
    `chat_id` BIGINT NOT NULL, -- context chat id
    `modified_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, -- row modify time
    `user` varchar(64) DEFAULT NULL, -- row modify user
    `query_text` LONGVARCHAR DEFAULT NULL, -- query text
    `semantic_parse` LONGVARCHAR DEFAULT NULL, -- parse data
    `ext_data` LONGVARCHAR DEFAULT NULL, -- extend data
    PRIMARY KEY (`chat_id`)
);


CREATE TABLE `chat`
(
    `chat_id` BIGINT NOT NULL, -- AUTO_INCREMENT,
    `chat_name` varchar(100) DEFAULT NULL,
    `create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    `last_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    `creator` varchar(30) DEFAULT NULL,
    `last_question` varchar(200) DEFAULT NULL,
    `is_delete` INT DEFAULT '0' COMMENT 'is deleted',
    `is_top` INT DEFAULT '0' COMMENT 'is top',
    PRIMARY KEY (`chat_id`)
);

CREATE TABLE `chat_query`
(
    `id` BIGINT NOT NULL, -- AUTO_INCREMENT,
    `question_id` BIGINT DEFAULT NULL,
    `create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    `user_name` varchar(150) DEFAULT NULL COMMENT '',
    `question` varchar(300) DEFAULT NULL COMMENT '',
    `query_result` LONGVARCHAR,
    `time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    `state` int(1) DEFAULT NULL,
    `data_content` varchar(30) DEFAULT NULL,
    `name` varchar(100) DEFAULT NULL,
    `scene_type` int(2) DEFAULT NULL,
    `query_type` int(2) DEFAULT NULL,
    `is_deleted` int(1) DEFAULT NULL,
    `module` varchar(30) DEFAULT NULL,
    `entity` LONGVARCHAR COMMENT '',
    `chat_id` BIGINT DEFAULT NULL COMMENT 'chat id',
    `recommend` text,
    `aggregator` varchar(20) DEFAULT 'trend',
    `top_num` int DEFAULT NULL,
    `start_time` varchar(30) DEFAULT NULL,
    `end_time` varchar(30) DEFAULT NULL,
    `compare_recommend` LONGVARCHAR,
    `compare_entity` LONGVARCHAR,
    `query_sql` LONGVARCHAR,
    `columns` varchar(2000) DEFAULT NULL,
    `result_list` LONGVARCHAR,
    `main_entity` varchar(5000) DEFAULT NULL,
    `semantic_text` varchar(5000) DEFAULT NULL,
    `score` int DEFAULT '0',
    `feedback` varchar(1024) DEFAULT '',
    PRIMARY KEY (`id`)
);
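Once Spring has applied chat-schema-h2.sql and chat-data-h2.sql to the in-memory database configured in the test application.yaml, the same JVM can read the seeded row over plain JDBC. The sketch below is illustrative only: the URL and credentials come from that yaml, while the H2SmokeCheck class and the idea of calling it from test code are assumptions, not part of this commit.

// Sketch (assumption): read the seeded chat_context row; only works from the same JVM
// while the test datasource created from application.yaml is still open.
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;

class H2SmokeCheck {
    static void printChatContext() throws Exception {
        try (Connection conn = DriverManager.getConnection(
                "jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false", "root", "semantic");
             Statement stmt = conn.createStatement();
             ResultSet rs = stmt.executeQuery("SELECT chat_id, query_text FROM chat_context")) {
            while (rs.next()) {
                // expected from chat-data-h2.sql: 1 -> 超音数访问次数
                System.out.println(rs.getLong("chat_id") + " -> " + rs.getString("query_text"));
            }
        }
    }
}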
1
headless/core/src/test/resources/hanlp.properties
Normal file
@@ -0,0 +1 @@
CustomDictionaryPath=data/dictionary/custom/DimValue_1_2.txt