mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(improvement)(headless)Introduce headless-chat. #1155
This commit is contained in:
@@ -1,75 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.AggCorrector;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import static org.testng.Assert.assertEquals;
|
||||
|
||||
class AggCorrectorTest {
|
||||
|
||||
@Test
|
||||
void testDoCorrect() {
|
||||
AggCorrector corrector = new AggCorrector();
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SchemaElement dataSet = new SchemaElement();
|
||||
dataSet.setDataSet(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
|
||||
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
|
||||
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(queryContext, semanticParseInfo);
|
||||
assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
|
||||
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
|
||||
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
}
|
||||
|
||||
private QueryContext buildQueryContext(Long dataSetId) {
|
||||
QueryContext queryContext = new QueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSet(1L);
|
||||
element1.setName("部门");
|
||||
dimensions.add(element1);
|
||||
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
SchemaElement metric1 = new SchemaElement();
|
||||
metric1.setDataSet(1L);
|
||||
metric1.setName("访问次数");
|
||||
metrics.add(metric1);
|
||||
|
||||
dataSetSchema.setMetrics(metrics);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
queryContext.setSemanticSchema(semanticSchema);
|
||||
return queryContext;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.ParseResult;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.testng.Assert.assertEquals;
|
||||
|
||||
class SchemaCorrectorTest {
|
||||
|
||||
private String json = "{\n"
|
||||
+ " \"dataSetId\": 1,\n"
|
||||
+ " \"llmReq\": {\n"
|
||||
+ " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
|
||||
+ " \"filterCondition\": {\n"
|
||||
+ " \"tableName\": null\n"
|
||||
+ " },\n"
|
||||
+ " \"schema\": {\n"
|
||||
+ " \"domainName\": \"歌曲\",\n"
|
||||
+ " \"dataSetName\": \"歌曲\",\n"
|
||||
+ " \"fieldNameList\": [\n"
|
||||
+ " \"商务组\",\n"
|
||||
+ " \"歌曲名\",\n"
|
||||
+ " \"播放量\",\n"
|
||||
+ " \"播放份额\",\n"
|
||||
+ " \"数据日期\"\n"
|
||||
+ " ]\n"
|
||||
+ " },\n"
|
||||
+ " \"linking\": [\n"
|
||||
+ "\n"
|
||||
+ " ],\n"
|
||||
+ " \"currentDate\": \"2024-02-24\",\n"
|
||||
+ " \"priorExts\": \"播放份额是小数; \",\n"
|
||||
+ " \"sqlGenerationMode\": \"2_pass_auto_cot\"\n"
|
||||
+ " },\n"
|
||||
+ " \"request\": null,\n"
|
||||
+ " \"commonAgentTool\": {\n"
|
||||
+ " \"id\": \"y3LqVSRL\",\n"
|
||||
+ " \"name\": \"大模型语义解析\",\n"
|
||||
+ " \"type\": \"NL2SQL_LLM\",\n"
|
||||
+ " \"dataSetIds\": [\n"
|
||||
+ " 1\n"
|
||||
+ " ]\n"
|
||||
+ " },\n"
|
||||
+ " \"linkingValues\": [\n"
|
||||
+ "\n"
|
||||
+ " ]\n"
|
||||
+ "}";
|
||||
|
||||
@Test
|
||||
void doCorrect() throws JsonProcessingException {
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
|
||||
String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
|
||||
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
semanticParseInfo.setDataSet(schemaElement);
|
||||
|
||||
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
||||
|
||||
assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
|
||||
parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
|
||||
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
|
||||
elementValue.setFieldName("商务组");
|
||||
elementValue.setFieldValue("xxx");
|
||||
linkingValues.add(elementValue);
|
||||
parseResult.setLinkingValues(linkingValues);
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setS2SQL(sql);
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
||||
assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
|
||||
}
|
||||
|
||||
private QueryContext buildQueryContext(Long dataSetId) {
|
||||
QueryContext queryContext = new QueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSet(1L);
|
||||
element1.setName("歌曲名");
|
||||
dimensions.add(element1);
|
||||
|
||||
SchemaElement element2 = new SchemaElement();
|
||||
element2.setDataSet(1L);
|
||||
element2.setName("商务组");
|
||||
dimensions.add(element2);
|
||||
|
||||
SchemaElement element3 = new SchemaElement();
|
||||
element3.setDataSet(1L);
|
||||
element3.setName("发行日期");
|
||||
dimensions.add(element3);
|
||||
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
queryContext.setSemanticSchema(semanticSchema);
|
||||
return queryContext;
|
||||
}
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.TimeCorrector;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testng.Assert;
|
||||
|
||||
class TimeCorrectorTest {
|
||||
|
||||
@Test
|
||||
void testDoCorrect() {
|
||||
TimeCorrector corrector = new TimeCorrector();
|
||||
QueryContext queryContext = new QueryContext();
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
//1.数据日期 <=
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//2.数据日期 <
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//3.数据日期 >=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//4.数据日期 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//5.no 数据日期
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//6. 数据日期-月 <=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
|
||||
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//7. 数据日期-月 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//8. no where
|
||||
sql = "SELECT COUNT(1) FROM 数据库";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class LoadRemoveServiceTest {
|
||||
|
||||
@Test
|
||||
void edit() {
|
||||
int compute = EditDistance.compute("在", "在你的身边");
|
||||
Assertions.assertEquals(compute, 4);
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.aggregate;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.rule.AggregateTypeParser;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testng.Assert;
|
||||
|
||||
class AggregateTypeParserTest {
|
||||
|
||||
@Test
|
||||
void getAggregateParser() {
|
||||
AggregateTypeParser aggregateParser = new AggregateTypeParser();
|
||||
AggregateTypeEnum aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.MAX);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品pv");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.COUNT);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品uv");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.DISTINCT);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.MAX);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最小值");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.MIN);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数平均值");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.AVG);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数topN");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.TOPN);
|
||||
|
||||
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数汇总");
|
||||
Assert.assertEquals(aggregateType, AggregateTypeEnum.SUM);
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.s2sql;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.LLMResponseService;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testng.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
class LLMResponseServiceTest {
|
||||
|
||||
@Test
|
||||
void deduplicationSqlWeight() {
|
||||
String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
|
||||
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
llmResp.setSqlRespMap(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp2 = new LLMResp();
|
||||
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
|
||||
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
llmResp2.setSqlRespMap(sqlWeight2);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp3 = new LLMResp();
|
||||
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
|
||||
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
llmResp3.setSqlRespMap(sqlWeight3);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.s2sql;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class LLMSqlParserTest {
|
||||
|
||||
@Test
|
||||
void setFilter() {
|
||||
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
|
||||
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
|
||||
SchemaValueMap value1 = new SchemaValueMap();
|
||||
value1.setBizName("杰伦");
|
||||
value1.setTechName("周杰伦");
|
||||
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
|
||||
schemaValueMaps.add(value1);
|
||||
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.dataSet(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
|
||||
SchemaElement schemaElement2 = SchemaElement.builder()
|
||||
.bizName("publish_time")
|
||||
.name("发布时间")
|
||||
.dataSet(2L)
|
||||
.build();
|
||||
dimensions.add(schemaElement2);
|
||||
|
||||
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
|
||||
|
||||
List<SchemaElement> metrics = new ArrayList<>();
|
||||
SchemaElement metric = SchemaElement.builder()
|
||||
.bizName("play_count")
|
||||
.name("播放量")
|
||||
.dataSet(2L)
|
||||
.build();
|
||||
metrics.add(metric);
|
||||
|
||||
when(mockSemanticSchema.getMetrics()).thenReturn(metrics);
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
class S2SqlDateHelperTest {
|
||||
|
||||
@Test
|
||||
void getReferenceDate() {
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
|
||||
String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
||||
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
||||
|
||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
||||
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||
timeDefaultConfig.setPeriod(Constants.DAY);
|
||||
timeDefaultConfig.setUnit(20);
|
||||
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));
|
||||
|
||||
timeDefaultConfig.setUnit(1);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));
|
||||
|
||||
timeDefaultConfig.setUnit(-1);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertNull(referenceDate);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getStartEndDate() {
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
|
||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.DETAIL);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
|
||||
Assert.assertNull(startEndDate.getLeft());
|
||||
Assert.assertNull(startEndDate.getRight());
|
||||
|
||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
||||
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||
timeDefaultConfig.setPeriod(Constants.DAY);
|
||||
timeDefaultConfig.setUnit(20);
|
||||
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
||||
|
||||
timeDefaultConfig.setUnit(2);
|
||||
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
||||
|
||||
timeDefaultConfig.setUnit(-1);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertNull(startEndDate.getLeft());
|
||||
Assert.assertNull(startEndDate.getRight());
|
||||
|
||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||
timeDefaultConfig.setPeriod(Constants.DAY);
|
||||
timeDefaultConfig.setUnit(5);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
|
||||
}
|
||||
|
||||
private QueryContext buildQueryContext(Long dataSetId) {
|
||||
QueryContext queryContext = new QueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
queryContext.setSemanticSchema(semanticSchema);
|
||||
return queryContext;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user