(improvement)(headless)Introduce headless-chat. #1155

This commit is contained in:
jerryjzhang
2024-06-15 18:56:39 +08:00
parent f2a12e56b7
commit c3ecc05715
158 changed files with 771 additions and 692 deletions

View File

@@ -1,75 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.corrector.AggCorrector;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.testng.Assert.assertEquals;
class AggCorrectorTest {
@Test
void testDoCorrect() {
AggCorrector corrector = new AggCorrector();
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SchemaElement dataSet = new SchemaElement();
dataSet.setDataSet(dataSetId);
semanticParseInfo.setDataSet(dataSet);
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(queryContext, semanticParseInfo);
assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
}
private QueryContext buildQueryContext(Long dataSetId) {
QueryContext queryContext = new QueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSet(1L);
element1.setName("部门");
dimensions.add(element1);
dataSetSchema.setDimensions(dimensions);
Set<SchemaElement> metrics = new HashSet<>();
SchemaElement metric1 = new SchemaElement();
metric1.setDataSet(1L);
metric1.setName("访问次数");
metrics.add(metric1);
dataSetSchema.setMetrics(metrics);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}
}

View File

@@ -1,145 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.core.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.testng.Assert.assertEquals;
class SchemaCorrectorTest {
private String json = "{\n"
+ " \"dataSetId\": 1,\n"
+ " \"llmReq\": {\n"
+ " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
+ " \"filterCondition\": {\n"
+ " \"tableName\": null\n"
+ " },\n"
+ " \"schema\": {\n"
+ " \"domainName\": \"歌曲\",\n"
+ " \"dataSetName\": \"歌曲\",\n"
+ " \"fieldNameList\": [\n"
+ " \"商务组\",\n"
+ " \"歌曲名\",\n"
+ " \"播放量\",\n"
+ " \"播放份额\",\n"
+ " \"数据日期\"\n"
+ " ]\n"
+ " },\n"
+ " \"linking\": [\n"
+ "\n"
+ " ],\n"
+ " \"currentDate\": \"2024-02-24\",\n"
+ " \"priorExts\": \"播放份额是小数; \",\n"
+ " \"sqlGenerationMode\": \"2_pass_auto_cot\"\n"
+ " },\n"
+ " \"request\": null,\n"
+ " \"commonAgentTool\": {\n"
+ " \"id\": \"y3LqVSRL\",\n"
+ " \"name\": \"大模型语义解析\",\n"
+ " \"type\": \"NL2SQL_LLM\",\n"
+ " \"dataSetIds\": [\n"
+ " 1\n"
+ " ]\n"
+ " },\n"
+ " \"linkingValues\": [\n"
+ "\n"
+ " ]\n"
+ "}";
@Test
void doCorrect() throws JsonProcessingException {
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
ObjectMapper objectMapper = new ObjectMapper();
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
semanticParseInfo.setDataSet(schemaElement);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
SchemaCorrector schemaCorrector = new SchemaCorrector();
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
parseResult = objectMapper.readValue(json, ParseResult.class);
List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
elementValue.setFieldName("商务组");
elementValue.setFieldValue("xxx");
linkingValues.add(elementValue);
parseResult.setLinkingValues(linkingValues);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
semanticParseInfo.getSqlInfo().setS2SQL(sql);
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
}
private QueryContext buildQueryContext(Long dataSetId) {
QueryContext queryContext = new QueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSet(1L);
element1.setName("歌曲名");
dimensions.add(element1);
SchemaElement element2 = new SchemaElement();
element2.setDataSet(1L);
element2.setName("商务组");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setDataSet(1L);
element3.setName("发行日期");
dimensions.add(element3);
dataSetSchema.setDimensions(dimensions);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}
}

View File

@@ -1,101 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.corrector.TimeCorrector;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import org.junit.jupiter.api.Test;
import org.testng.Assert;
class TimeCorrectorTest {
@Test
void testDoCorrect() {
TimeCorrector corrector = new TimeCorrector();
QueryContext queryContext = new QueryContext();
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
//1.数据日期 <=
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//2.数据日期 <
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//3.数据日期 >=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//4.数据日期 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//5.no 数据日期
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//6. 数据日期-月 <=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//7. 数据日期-月 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//8. no where
sql = "SELECT COUNT(1) FROM 数据库";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
}
}

View File

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class LoadRemoveServiceTest {
@Test
void edit() {
int compute = EditDistance.compute("", "在你的身边");
Assertions.assertEquals(compute, 4);
}
}

View File

@@ -1,37 +0,0 @@
package com.tencent.supersonic.chat.core.parser.aggregate;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.core.chat.parser.rule.AggregateTypeParser;
import org.junit.jupiter.api.Test;
import org.testng.Assert;
class AggregateTypeParserTest {
@Test
void getAggregateParser() {
AggregateTypeParser aggregateParser = new AggregateTypeParser();
AggregateTypeEnum aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
Assert.assertEquals(aggregateType, AggregateTypeEnum.MAX);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品pv");
Assert.assertEquals(aggregateType, AggregateTypeEnum.COUNT);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品uv");
Assert.assertEquals(aggregateType, AggregateTypeEnum.DISTINCT);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最大值");
Assert.assertEquals(aggregateType, AggregateTypeEnum.MAX);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数最小值");
Assert.assertEquals(aggregateType, AggregateTypeEnum.MIN);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数平均值");
Assert.assertEquals(aggregateType, AggregateTypeEnum.AVG);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数topN");
Assert.assertEquals(aggregateType, AggregateTypeEnum.TOPN);
aggregateType = aggregateParser.resolveAggregateType("supsersonic产品访问次数汇总");
Assert.assertEquals(aggregateType, AggregateTypeEnum.SUM);
}
}

View File

@@ -1,56 +0,0 @@
package com.tencent.supersonic.chat.core.s2sql;
import com.tencent.supersonic.headless.core.chat.parser.llm.LLMResponseService;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
import org.junit.jupiter.api.Test;
import org.testng.Assert;
import java.util.HashMap;
import java.util.Map;
class LLMResponseServiceTest {
@Test
void deduplicationSqlWeight() {
String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp = new LLMResp();
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp.setSqlRespMap(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp2 = new LLMResp();
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp2.setSqlRespMap(sqlWeight2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp3 = new LLMResp();
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp3.setSqlRespMap(sqlWeight3);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
Assert.assertEquals(deduplicationSqlResp.size(), 2);
}
}

View File

@@ -1,56 +0,0 @@
package com.tencent.supersonic.chat.core.s2sql;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.mockito.Mockito.when;
class LLMSqlParserTest {
@Test
void setFilter() {
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
SchemaValueMap value1 = new SchemaValueMap();
value1.setBizName("杰伦");
value1.setTechName("周杰伦");
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
schemaValueMaps.add(value1);
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.dataSet(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
SchemaElement schemaElement2 = SchemaElement.builder()
.bizName("publish_time")
.name("发布时间")
.dataSet(2L)
.build();
dimensions.add(schemaElement2);
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
List<SchemaElement> metrics = new ArrayList<>();
SchemaElement metric = SchemaElement.builder()
.bizName("play_count")
.name("播放量")
.dataSet(2L)
.build();
metrics.add(metric);
when(mockSemanticSchema.getMetrics()).thenReturn(metrics);
}
}

View File

@@ -1,123 +0,0 @@
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
class S2SqlDateHelperTest {
@Test
void getReferenceDate() {
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(20);
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));
timeDefaultConfig.setUnit(1);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));
timeDefaultConfig.setUnit(-1);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertNull(referenceDate);
}
@Test
void getStartEndDate() {
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.DETAIL);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
Assert.assertNull(startEndDate.getLeft());
Assert.assertNull(startEndDate.getRight());
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(20);
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
timeDefaultConfig.setUnit(2);
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
timeDefaultConfig.setUnit(-1);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertNull(startEndDate.getLeft());
Assert.assertNull(startEndDate.getRight());
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(5);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
}
private QueryContext buildQueryContext(Long dataSetId) {
QueryContext queryContext = new QueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
dataSetSchema.setDataSet(schemaElement);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}
}