mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[improvement][headless&chat]当LLM生成SQL包含日期类型字段时,Correcter不再额外增加日期 (#1473)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.headless.api.pojo;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -25,6 +26,7 @@ public class SchemaElement implements Serializable {
|
|||||||
private String bizName;
|
private String bizName;
|
||||||
private Long useCnt;
|
private Long useCnt;
|
||||||
private SchemaElementType type;
|
private SchemaElementType type;
|
||||||
|
private SemanticType semanticType;
|
||||||
private List<String> alias;
|
private List<String> alias;
|
||||||
private List<SchemaValueMap> schemaValueMaps;
|
private List<SchemaValueMap> schemaValueMaps;
|
||||||
private List<RelatedSchemaElement> relatedSchemaElements;
|
private List<RelatedSchemaElement> relatedSchemaElements;
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
|
|||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
|
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -56,6 +58,15 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean checkIfNameInWhereFields(Set<SchemaElement> dims, List<String> whereFields) {
|
||||||
|
for (SchemaElement element : dims) {
|
||||||
|
if (whereFields.contains(element.getName())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||||
@@ -66,6 +77,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
|||||||
if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
|
if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||||
|
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
|
boolean isDateInWhere = checkIfNameInWhereFields(dataSetSchema.getDimensions(), whereFields);
|
||||||
|
if (isDateInWhere) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||||
|
|
||||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext,
|
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||||
@@ -154,6 +155,7 @@ public class DataSetSchemaBuilder {
|
|||||||
schemaValueMaps.add(schemaValueMap);
|
schemaValueMaps.add(schemaValueMap);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
SemanticType semanticType = SemanticType.valueOf(dim.getSemanticType());
|
||||||
SchemaElement dimToAdd = SchemaElement.builder()
|
SchemaElement dimToAdd = SchemaElement.builder()
|
||||||
.dataSetId(resp.getId())
|
.dataSetId(resp.getId())
|
||||||
.dataSetName(resp.getName())
|
.dataSetName(resp.getName())
|
||||||
@@ -162,6 +164,7 @@ public class DataSetSchemaBuilder {
|
|||||||
.name(dim.getName())
|
.name(dim.getName())
|
||||||
.bizName(dim.getBizName())
|
.bizName(dim.getBizName())
|
||||||
.type(SchemaElementType.DIMENSION)
|
.type(SchemaElementType.DIMENSION)
|
||||||
|
.semanticType(semanticType)
|
||||||
.useCnt(dim.getUseCnt())
|
.useCnt(dim.getUseCnt())
|
||||||
.alias(alias)
|
.alias(alias)
|
||||||
.schemaValueMaps(schemaValueMaps)
|
.schemaValueMaps(schemaValueMaps)
|
||||||
|
|||||||
@@ -108,7 +108,11 @@ public class ModelConverter {
|
|||||||
dimensionReq.setName(dim.getName());
|
dimensionReq.setName(dim.getName());
|
||||||
dimensionReq.setBizName(dim.getBizName());
|
dimensionReq.setBizName(dim.getBizName());
|
||||||
dimensionReq.setDescription(dim.getName());
|
dimensionReq.setDescription(dim.getName());
|
||||||
dimensionReq.setSemanticType(SemanticType.CATEGORY.name());
|
if (Objects.equals(dim.getType(), DimensionType.time.name())) {
|
||||||
|
dimensionReq.setSemanticType(SemanticType.DATE.name());
|
||||||
|
} else {
|
||||||
|
dimensionReq.setSemanticType(SemanticType.CATEGORY.name());
|
||||||
|
}
|
||||||
dimensionReq.setModelId(modelDO.getId());
|
dimensionReq.setModelId(modelDO.getId());
|
||||||
dimensionReq.setExpr(dim.getBizName());
|
dimensionReq.setExpr(dim.getBizName());
|
||||||
dimensionReq.setType(DimensionType.categorical.name());
|
dimensionReq.setType(DimensionType.categorical.name());
|
||||||
@@ -138,7 +142,11 @@ public class ModelConverter {
|
|||||||
dimensionReq.setName(identify.getName());
|
dimensionReq.setName(identify.getName());
|
||||||
dimensionReq.setBizName(identify.getBizName());
|
dimensionReq.setBizName(identify.getBizName());
|
||||||
dimensionReq.setDescription(identify.getName());
|
dimensionReq.setDescription(identify.getName());
|
||||||
dimensionReq.setSemanticType(SemanticType.CATEGORY.name());
|
if (Objects.equals(identify.getType(), DimensionType.time.name())) {
|
||||||
|
dimensionReq.setSemanticType(SemanticType.DATE.name());
|
||||||
|
} else {
|
||||||
|
dimensionReq.setSemanticType(SemanticType.CATEGORY.name());
|
||||||
|
}
|
||||||
dimensionReq.setModelId(modelDO.getId());
|
dimensionReq.setModelId(modelDO.getId());
|
||||||
dimensionReq.setExpr(identify.getBizName());
|
dimensionReq.setExpr(identify.getBizName());
|
||||||
dimensionReq.setType(identify.getType());
|
dimensionReq.setType(identify.getType());
|
||||||
|
|||||||
Reference in New Issue
Block a user