mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-20 06:34:55 +00:00
(improvement)(chat) logic sql show in chinese and convert to bizName in execute (#156)
This commit is contained in:
@@ -6,10 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -23,14 +23,11 @@ import org.springframework.util.CollectionUtils;
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public static final String DATE_FIELD = "数据日期";
|
||||
|
||||
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
||||
}
|
||||
|
||||
protected Map<String, String> getFieldToBizName(Long modelId) {
|
||||
protected Map<String, String> getFieldNameMap(Long modelId) {
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
@@ -40,8 +37,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> entry.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
|
||||
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
|
||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getName(), (k1, k2) -> k1));
|
||||
result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD);
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -55,9 +52,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
whereFields.remove(DateUtils.DATE_FIELD);
|
||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
semanticCorrectInfo.setSql(replaceFields);
|
||||
}
|
||||
@@ -75,7 +70,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||
}
|
||||
return schemaElement;
|
||||
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
|
||||
}).collect(Collectors.toMap(a -> a.getName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
|
||||
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
|
||||
@@ -27,7 +27,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
updateFieldNameByLinkingValue(semanticCorrectInfo);
|
||||
|
||||
updateFieldNameByBizName(semanticCorrectInfo);
|
||||
correctFieldName(semanticCorrectInfo);
|
||||
|
||||
addAggregateToMetric(semanticCorrectInfo);
|
||||
}
|
||||
@@ -39,11 +39,11 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
semanticCorrectInfo.setSql(replaceAlias);
|
||||
}
|
||||
|
||||
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId());
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId());
|
||||
|
||||
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName);
|
||||
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap);
|
||||
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.corrector;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -28,8 +28,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
|
||||
.filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName()))
|
||||
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
|
||||
.filter(schemaElement -> !DateUtils.DATE_FIELD.equals(schemaElement.getBizName()))
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
||||
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
|
||||
@@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -70,9 +70,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) {
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
|
||||
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate);
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
|
||||
}
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
}
|
||||
@@ -83,7 +83,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
return queryFilters.getFilters().stream()
|
||||
.map(filter -> {
|
||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
|
||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||
return bizNameWrap + operatorWrap + valueWrap;
|
||||
@@ -117,11 +117,11 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
for (SchemaElement dimension : dimensions) {
|
||||
if (Objects.isNull(dimension)
|
||||
|| Strings.isEmpty(dimension.getBizName())
|
||||
|| Strings.isEmpty(dimension.getName())
|
||||
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
|
||||
continue;
|
||||
}
|
||||
String bizName = dimension.getBizName();
|
||||
String name = dimension.getName();
|
||||
|
||||
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
|
||||
|
||||
@@ -141,7 +141,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
|
||||
result.put(bizName, aliasAndBizNameToTechName);
|
||||
result.put(name, aliasAndBizNameToTechName);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
||||
@@ -16,7 +16,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
@@ -31,11 +30,11 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -113,7 +112,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
||||
&& allFields.contains(schemaElement.getBizName())
|
||||
&& allFields.contains(schemaElement.getName())
|
||||
).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
@@ -122,7 +121,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return allFields.stream()
|
||||
.filter(entry -> !TimeDimensionEnum.getNameList().contains(entry))
|
||||
.filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -130,6 +129,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
|
||||
String correctorSql = semanticCorrectInfo.getSql();
|
||||
parseInfo.getSqlInfo().setLogicSql(correctorSql);
|
||||
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
||||
//set dataInfo
|
||||
try {
|
||||
@@ -143,8 +143,8 @@ public class LLMDslParser implements SemanticParser {
|
||||
|
||||
//set filter
|
||||
try {
|
||||
Map<String, SchemaElement> bizNameToElement = getBizNameToElement(modelId);
|
||||
List<QueryFilter> result = getDimensionFilter(bizNameToElement, expressions);
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
|
||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
||||
parseInfo.getDimensionFilters().addAll(result);
|
||||
} catch (Exception e) {
|
||||
log.error("set dimensionFilter error :", e);
|
||||
@@ -173,20 +173,18 @@ public class LLMDslParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FilterExpression> filterExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FilterExpression expression : filterExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
dimensionFilter.setValue(expression.getFieldValue());
|
||||
String bizName = expression.getFieldName();
|
||||
SchemaElement schemaElement = bizNameToElement.get(bizName);
|
||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
continue;
|
||||
}
|
||||
String fieldName = schemaElement.getName();
|
||||
dimensionFilter.setName(fieldName);
|
||||
dimensionFilter.setBizName(bizName);
|
||||
dimensionFilter.setName(schemaElement.getName());
|
||||
dimensionFilter.setBizName(schemaElement.getBizName());
|
||||
dimensionFilter.setElementID(schemaElement.getId());
|
||||
|
||||
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
||||
@@ -198,13 +196,8 @@ public class LLMDslParser implements SemanticParser {
|
||||
|
||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
||||
.filter(expression -> {
|
||||
List<String> nameList = TimeDimensionEnum.getNameList();
|
||||
if (StringUtils.isEmpty(expression.getFieldName())) {
|
||||
return false;
|
||||
}
|
||||
return nameList.contains(expression.getFieldName().toLowerCase());
|
||||
}).collect(Collectors.toList());
|
||||
.filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(dateExpressions)) {
|
||||
return new DateConf();
|
||||
}
|
||||
@@ -354,7 +347,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
|
||||
|
||||
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
|
||||
fieldNameList.add(DateUtils.DATE_FIELD);
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
@@ -391,7 +384,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
}
|
||||
|
||||
|
||||
protected Map<String, SchemaElement> getBizNameToElement(Long modelId) {
|
||||
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics();
|
||||
@@ -401,7 +394,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
allElements.addAll(metrics);
|
||||
return allElements.stream()
|
||||
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2));
|
||||
.collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class LLMDslParserTest {
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ")
|
||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 ")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user