From 53a9f7c451bbf5c57e2691a4a27d15569d89c214 Mon Sep 17 00:00:00 2001 From: yudong Date: Fri, 2 Aug 2024 14:14:29 +0800 Subject: [PATCH] =?UTF-8?q?[improvement][headless&chat]=E5=BD=93LLM?= =?UTF-8?q?=E7=94=9F=E6=88=90SQL=E5=8C=85=E5=90=AB=E6=97=A5=E6=9C=9F?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E5=AD=97=E6=AE=B5=E6=97=B6=EF=BC=8CCorrecter?= =?UTF-8?q?=E4=B8=8D=E5=86=8D=E9=A2=9D=E5=A4=96=E5=A2=9E=E5=8A=A0=E6=97=A5?= =?UTF-8?q?=E6=9C=9F=20(#1473)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../headless/api/pojo/SchemaElement.java | 2 ++ .../headless/chat/corrector/TimeCorrector.java | 17 +++++++++++++++++ .../server/utils/DataSetSchemaBuilder.java | 3 +++ .../headless/server/utils/ModelConverter.java | 12 ++++++++++-- 4 files changed, 32 insertions(+), 2 deletions(-) diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index 98a846a99..a0c48e1e7 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.api.pojo; import com.google.common.base.Objects; +import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -25,6 +26,7 @@ public class SchemaElement implements Serializable { private String bizName; private Long useCnt; private SchemaElementType type; + private SemanticType semanticType; private List alias; private List schemaValueMaps; private List relatedSchemaElements; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index bd68e2eeb..b54fcc2ab 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -8,6 +8,8 @@ import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; @@ -56,6 +58,15 @@ public class TimeCorrector extends BaseSemanticCorrector { } } + private boolean checkIfNameInWhereFields(Set dims, List whereFields) { + for (SchemaElement element : dims) { + if (whereFields.contains(element.getName())) { + return true; + } + } + return false; + } + private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); @@ -66,6 +77,12 @@ public class TimeCorrector extends BaseSemanticCorrector { if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) { return; } + Long dataSetId = semanticParseInfo.getDataSetId(); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); + boolean isDateInWhere = checkIfNameInWhereFields(dataSetSchema.getDimensions(), whereFields); + if (isDateInWhere) { + return; + } if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) { Pair startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java index 0f693b38d..44a5a90ba 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaItem; import com.tencent.supersonic.headless.api.pojo.SchemaValueMap; +import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; @@ -154,6 +155,7 @@ public class DataSetSchemaBuilder { schemaValueMaps.add(schemaValueMap); } } + SemanticType semanticType = SemanticType.valueOf(dim.getSemanticType()); SchemaElement dimToAdd = SchemaElement.builder() .dataSetId(resp.getId()) .dataSetName(resp.getName()) @@ -162,6 +164,7 @@ public class DataSetSchemaBuilder { .name(dim.getName()) .bizName(dim.getBizName()) .type(SchemaElementType.DIMENSION) + .semanticType(semanticType) .useCnt(dim.getUseCnt()) .alias(alias) .schemaValueMaps(schemaValueMaps) diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java index ab309e1f4..1ffa5e581 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java @@ -108,7 +108,11 @@ public class ModelConverter { dimensionReq.setName(dim.getName()); dimensionReq.setBizName(dim.getBizName()); dimensionReq.setDescription(dim.getName()); - dimensionReq.setSemanticType(SemanticType.CATEGORY.name()); + if (Objects.equals(dim.getType(), DimensionType.time.name())) { + dimensionReq.setSemanticType(SemanticType.DATE.name()); + } else { + dimensionReq.setSemanticType(SemanticType.CATEGORY.name()); + } dimensionReq.setModelId(modelDO.getId()); dimensionReq.setExpr(dim.getBizName()); dimensionReq.setType(DimensionType.categorical.name()); @@ -138,7 +142,11 @@ public class ModelConverter { dimensionReq.setName(identify.getName()); dimensionReq.setBizName(identify.getBizName()); dimensionReq.setDescription(identify.getName()); - dimensionReq.setSemanticType(SemanticType.CATEGORY.name()); + if (Objects.equals(identify.getType(), DimensionType.time.name())) { + dimensionReq.setSemanticType(SemanticType.DATE.name()); + } else { + dimensionReq.setSemanticType(SemanticType.CATEGORY.name()); + } dimensionReq.setModelId(modelDO.getId()); dimensionReq.setExpr(identify.getBizName()); dimensionReq.setType(identify.getType());