diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java index 7c471f320..7d6463d44 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java @@ -18,6 +18,7 @@ import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.util.Strings; import org.springframework.util.CollectionUtils; @@ -65,11 +66,20 @@ public class WhereCorrector extends BaseSemanticCorrector { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) { - String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getViewId()); - if (StringUtils.isNotBlank(currentDate)) { + Pair startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, + semanticParseInfo.getViewId(), semanticParseInfo.getQueryType()); + if (StringUtils.isNotBlank(startEndDate.getLeft()) + && StringUtils.isNotBlank(startEndDate.getRight())) { correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); - correctS2SQL = SqlAddHelper.addWhere( - correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate); + String dateChName = TimeDimensionEnum.DAY.getChName(); + String condExpr = String.format(" ( %s >= %s and %s <= %s )", dateChName, + startEndDate.getLeft(), dateChName, startEndDate.getRight()); + try { + Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr); + correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression); + } catch (JSQLParserException e) { + log.error("parseCondExpression:{}", e); + } } } semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelper.java index 17589b566..5421d47c2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelper.java @@ -2,11 +2,12 @@ package com.tencent.supersonic.chat.core.parser.sql.llm; import com.tencent.supersonic.chat.api.pojo.ViewSchema; import com.tencent.supersonic.chat.core.pojo.QueryContext; +import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.util.DatePeriodEnum; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; - import java.util.Objects; +import org.apache.commons.lang3.tuple.Pair; public class S2SqlDateHelper { @@ -20,21 +21,46 @@ public class S2SqlDateHelper { return defaultDate; } TimeDefaultConfig tagTypeTimeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig(); - Integer unit = tagTypeTimeDefaultConfig.getUnit(); - String period = tagTypeTimeDefaultConfig.getPeriod(); + return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft(); + } + + public static Pair getStartEndDate(QueryContext queryContext, + Long viewId, QueryType queryType) { + String defaultDate = DateUtils.getBeforeDate(0); + if (Objects.isNull(viewId)) { + return Pair.of(defaultDate, defaultDate); + } + ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId); + if (viewSchema == null) { + return Pair.of(defaultDate, defaultDate); + } + TimeDefaultConfig defaultConfig = viewSchema.getMetricTypeTimeDefaultConfig(); + if (QueryType.TAG.equals(queryType)) { + defaultConfig = viewSchema.getTagTypeTimeDefaultConfig(); + } + return getDefaultDate(defaultDate, defaultConfig); + } + + private static Pair getDefaultDate(String defaultDate, TimeDefaultConfig defaultConfig) { + if (Objects.isNull(defaultConfig)) { + return Pair.of(null, null); + } + Integer unit = defaultConfig.getUnit(); + String period = defaultConfig.getPeriod(); if (Objects.nonNull(unit)) { // If the unit is set to less than 0, then do not add relative date. if (unit < 0) { - return null; + return Pair.of(null, null); } DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period); if (Objects.isNull(datePeriodEnum)) { - return DateUtils.getBeforeDate(unit); + return Pair.of(DateUtils.getBeforeDate(unit), DateUtils.getBeforeDate(1)); } else { - return DateUtils.getBeforeDate(unit, datePeriodEnum); + return Pair.of(DateUtils.getBeforeDate(unit, datePeriodEnum), + DateUtils.getBeforeDate(1, datePeriodEnum)); } } - return defaultDate; + return Pair.of(defaultDate, defaultDate); } } diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelperTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelperTest.java new file mode 100644 index 000000000..acca8a813 --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelperTest.java @@ -0,0 +1,103 @@ +package com.tencent.supersonic.chat.core.parser.sql.llm; + +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.ViewSchema; +import com.tencent.supersonic.chat.core.pojo.QueryContext; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.enums.QueryType; +import com.tencent.supersonic.common.pojo.enums.TimeMode; +import com.tencent.supersonic.headless.api.pojo.QueryConfig; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +class S2SqlDateHelperTest { + + @Test + void getReferenceDate() { + Long viewId = 1L; + QueryContext queryContext = buildQueryContext(viewId); + + String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null); + Assert.assertNotNull(referenceDate); + + referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId); + Assert.assertNotNull(referenceDate); + + ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId); + QueryConfig queryConfig = viewSchema.getQueryConfig(); + TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(); + timeDefaultConfig.setTimeMode(TimeMode.LAST); + timeDefaultConfig.setPeriod(Constants.DAY); + timeDefaultConfig.setUnit(20); + queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig); + + referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId); + Assert.assertNotNull(referenceDate); + + timeDefaultConfig.setUnit(1); + referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId); + Assert.assertNotNull(referenceDate); + + timeDefaultConfig.setUnit(-1); + referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId); + Assert.assertNull(referenceDate); + } + + @Test + void getStartEndDate() { + Long viewId = 1L; + QueryContext queryContext = buildQueryContext(viewId); + + Pair startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.TAG); + Assert.assertNotNull(startEndDate); + + startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG); + Assert.assertNotNull(startEndDate); + + ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId); + QueryConfig queryConfig = viewSchema.getQueryConfig(); + TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(); + timeDefaultConfig.setTimeMode(TimeMode.LAST); + timeDefaultConfig.setPeriod(Constants.DAY); + timeDefaultConfig.setUnit(20); + queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig); + queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig); + + startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG); + Assert.assertNotNull(startEndDate); + + startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC); + Assert.assertNotNull(startEndDate); + + timeDefaultConfig.setUnit(1); + timeDefaultConfig.setTimeMode(TimeMode.RECENT); + startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC); + Assert.assertNotNull(startEndDate); + + timeDefaultConfig.setUnit(-1); + startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC); + Assert.assertNull(startEndDate.getLeft()); + Assert.assertNull(startEndDate.getRight()); + } + + private QueryContext buildQueryContext(Long viewId) { + QueryContext queryContext = new QueryContext(); + List viewSchemaList = new ArrayList<>(); + ViewSchema viewSchema = new ViewSchema(); + QueryConfig queryConfig = new QueryConfig(); + viewSchema.setQueryConfig(queryConfig); + SchemaElement schemaElement = new SchemaElement(); + schemaElement.setView(viewId); + viewSchema.setView(schemaElement); + viewSchemaList.add(viewSchema); + + SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList); + queryContext.setSemanticSchema(semanticSchema); + return queryContext; + } +} \ No newline at end of file