[improvement][headless&chat]当LLM生成SQL包含日期类型字段时,Correcter不再额外增加日期 (#1473)

This commit is contained in:
yudong
2024-08-02 14:14:29 +08:00
committed by GitHub
parent e26263d229
commit 53a9f7c451
4 changed files with 32 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
@@ -25,6 +26,7 @@ public class SchemaElement implements Serializable {
private String bizName; private String bizName;
private Long useCnt; private Long useCnt;
private SchemaElementType type; private SchemaElementType type;
private SemanticType semanticType;
private List<String> alias; private List<String> alias;
private List<SchemaValueMap> schemaValueMaps; private List<SchemaValueMap> schemaValueMaps;
private List<RelatedSchemaElement> relatedSchemaElements; private List<RelatedSchemaElement> relatedSchemaElements;

View File

@@ -8,6 +8,8 @@ import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -56,6 +58,15 @@ public class TimeCorrector extends BaseSemanticCorrector {
} }
} }
private boolean checkIfNameInWhereFields(Set<SchemaElement> dims, List<String> whereFields) {
for (SchemaElement element : dims) {
if (whereFields.contains(element.getName())) {
return true;
}
}
return false;
}
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
@@ -66,6 +77,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) { if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
return; return;
} }
Long dataSetId = semanticParseInfo.getDataSetId();
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
boolean isDateInWhere = checkIfNameInWhereFields(dataSetSchema.getDimensions(), whereFields);
if (isDateInWhere) {
return;
}
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) { if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext,

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaItem; import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap; import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
@@ -154,6 +155,7 @@ public class DataSetSchemaBuilder {
schemaValueMaps.add(schemaValueMap); schemaValueMaps.add(schemaValueMap);
} }
} }
SemanticType semanticType = SemanticType.valueOf(dim.getSemanticType());
SchemaElement dimToAdd = SchemaElement.builder() SchemaElement dimToAdd = SchemaElement.builder()
.dataSetId(resp.getId()) .dataSetId(resp.getId())
.dataSetName(resp.getName()) .dataSetName(resp.getName())
@@ -162,6 +164,7 @@ public class DataSetSchemaBuilder {
.name(dim.getName()) .name(dim.getName())
.bizName(dim.getBizName()) .bizName(dim.getBizName())
.type(SchemaElementType.DIMENSION) .type(SchemaElementType.DIMENSION)
.semanticType(semanticType)
.useCnt(dim.getUseCnt()) .useCnt(dim.getUseCnt())
.alias(alias) .alias(alias)
.schemaValueMaps(schemaValueMaps) .schemaValueMaps(schemaValueMaps)

View File

@@ -108,7 +108,11 @@ public class ModelConverter {
dimensionReq.setName(dim.getName()); dimensionReq.setName(dim.getName());
dimensionReq.setBizName(dim.getBizName()); dimensionReq.setBizName(dim.getBizName());
dimensionReq.setDescription(dim.getName()); dimensionReq.setDescription(dim.getName());
dimensionReq.setSemanticType(SemanticType.CATEGORY.name()); if (Objects.equals(dim.getType(), DimensionType.time.name())) {
dimensionReq.setSemanticType(SemanticType.DATE.name());
} else {
dimensionReq.setSemanticType(SemanticType.CATEGORY.name());
}
dimensionReq.setModelId(modelDO.getId()); dimensionReq.setModelId(modelDO.getId());
dimensionReq.setExpr(dim.getBizName()); dimensionReq.setExpr(dim.getBizName());
dimensionReq.setType(DimensionType.categorical.name()); dimensionReq.setType(DimensionType.categorical.name());
@@ -138,7 +142,11 @@ public class ModelConverter {
dimensionReq.setName(identify.getName()); dimensionReq.setName(identify.getName());
dimensionReq.setBizName(identify.getBizName()); dimensionReq.setBizName(identify.getBizName());
dimensionReq.setDescription(identify.getName()); dimensionReq.setDescription(identify.getName());
dimensionReq.setSemanticType(SemanticType.CATEGORY.name()); if (Objects.equals(identify.getType(), DimensionType.time.name())) {
dimensionReq.setSemanticType(SemanticType.DATE.name());
} else {
dimensionReq.setSemanticType(SemanticType.CATEGORY.name());
}
dimensionReq.setModelId(modelDO.getId()); dimensionReq.setModelId(modelDO.getId());
dimensionReq.setExpr(identify.getBizName()); dimensionReq.setExpr(identify.getBizName());
dimensionReq.setType(identify.getType()); dimensionReq.setType(identify.getType());