diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java index 96790faa1..a7f578a35 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java @@ -1,10 +1,5 @@ package com.tencent.supersonic.common.jsqlparser; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Objects; -import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.BinaryExpression; @@ -33,6 +28,12 @@ import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; + /** * Sql Parser remove Helper */ @@ -228,7 +229,6 @@ public class SqlRemoveHelper { if (selectStatement == null) { return sql; } - //SelectBody selectBody = selectStatement.getSelectBody(); if (!(selectStatement instanceof PlainSelect)) { return sql; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java index 0bb5d46d1..a7caa7b70 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java @@ -61,14 +61,15 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return elements.stream(); }) .collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1)); + if (chatQueryContext.containsPartitionDimensions(dataSetId)) { + result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName()); + result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName()); + result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName()); - result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName()); - result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName()); - result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName()); - - result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName()); - result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName()); - result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName()); + result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName()); + result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName()); + result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName()); + } return result; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index 52022edb3..c609e2cfa 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -47,6 +47,7 @@ public class TimeCorrector extends BaseSemanticCorrector { removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); + correctS2SQL = SqlRemoveHelper.removeGroupBy(correctS2SQL, removeFieldNames); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index eeedb8d66..a9a4424a8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker; @@ -231,38 +230,24 @@ public class LLMRequestService { protected List getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); - Set results = new HashSet<>(); - - if (!CollectionUtils.isEmpty(matchedElements)) { - results = matchedElements.stream() - .filter(element -> SchemaElementType.DIMENSION.equals(element.getElement().getType())) - .map(SchemaElementMatch::getElement) - .collect(Collectors.toSet()); - } - + Set dimensionElements = matchedElements.stream() + .filter(element -> SchemaElementType.DIMENSION.equals(element.getElement().getType())) + .map(SchemaElementMatch::getElement) + .collect(Collectors.toSet()); SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { - return new ArrayList<>(results); + return new ArrayList<>(dimensionElements); } - DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); + Map dataSetSchemaMap = semanticSchema.getDataSetSchemaMap(); + DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId); if (dataSetSchema == null) { - return new ArrayList<>(results); + return new ArrayList<>(dimensionElements); } - - TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig(); SchemaElement partitionDimension = dataSetSchema.getPartitionDimension(); - - if (timeDefaultConfig == null || partitionDimension == null) { - return new ArrayList<>(results); + if (partitionDimension != null) { + dimensionElements.add(partitionDimension); } - - if (Objects.equals(timeDefaultConfig.getUnit(), -1)) { - results.remove(partitionDimension); - } else { - results.add(partitionDimension); - } - - return new ArrayList<>(results); + return new ArrayList<>(dimensionElements); } }