(improvement)(chat) show logic SQL with Chinese field names and convert them to bizName on execution (#156)

This commit is contained in:
lexluo09
2023-09-27 17:27:31 +08:00
committed by GitHub
parent f931951ad5
commit 617db611c3
17 changed files with 138 additions and 100 deletions

View File

@@ -6,10 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -23,14 +23,11 @@ import org.springframework.util.CollectionUtils;
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
public static final String DATE_FIELD = "数据日期";
/**
 * Copies the SQL currently held by the given correction info into its
 * {@code preSql} property.
 *
 * @param semanticCorrectInfo the correction context whose current SQL is recorded
 */
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
    // Read the SQL first, then store it as the "pre" value on the same object.
    String currentSql = semanticCorrectInfo.getSql();
    semanticCorrectInfo.setPreSql(currentSql);
}
protected Map<String, String> getFieldToBizName(Long modelId) {
protected Map<String, String> getFieldNameMap(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -40,8 +37,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
Map<String, String> result = dbAllFields.stream()
.filter(entry -> entry.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getName(), (k1, k2) -> k1));
result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD);
return result;
}
@@ -55,9 +52,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName());
whereFields.remove(DateUtils.DATE_FIELD);
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
semanticCorrectInfo.setSql(replaceFields);
}
@@ -75,7 +70,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
}).collect(Collectors.toMap(a -> a.getName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;

View File

@@ -27,7 +27,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldNameByBizName(semanticCorrectInfo);
correctFieldName(semanticCorrectInfo);
addAggregateToMetric(semanticCorrectInfo);
}
@@ -39,11 +39,11 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
semanticCorrectInfo.setSql(replaceAlias);
}
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) {
private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) {
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId());
Map<String, String> fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId());
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName);
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap);
semanticCorrectInfo.setSql(sql);
}

View File

@@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@@ -28,8 +28,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
.filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName()))
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
.filter(schemaElement -> !DateUtils.DATE_FIELD.equals(schemaElement.getBizName()))
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);

View File

@@ -25,7 +25,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) {
return;

View File

@@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -70,9 +70,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) {
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate);
sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
}
semanticCorrectInfo.setSql(sql);
}
@@ -83,7 +83,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
@@ -117,11 +117,11 @@ public class WhereCorrector extends BaseSemanticCorrector {
for (SchemaElement dimension : dimensions) {
if (Objects.isNull(dimension)
|| Strings.isEmpty(dimension.getBizName())
|| Strings.isEmpty(dimension.getName())
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
continue;
}
String bizName = dimension.getBizName();
String name = dimension.getName();
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
@@ -141,7 +141,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
}
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
result.put(bizName, aliasAndBizNameToTechName);
result.put(name, aliasAndBizNameToTechName);
}
}
return result;

View File

@@ -16,7 +16,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver;
import com.tencent.supersonic.chat.query.QueryManager;
@@ -31,11 +30,11 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList;
import java.util.Arrays;
@@ -113,7 +112,7 @@ public class LLMDslParser implements SemanticParser {
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
&& allFields.contains(schemaElement.getBizName())
&& allFields.contains(schemaElement.getName())
).collect(Collectors.toSet());
}
@@ -122,7 +121,7 @@ public class LLMDslParser implements SemanticParser {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !TimeDimensionEnum.getNameList().contains(entry))
.filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry))
.collect(Collectors.toList());
}
@@ -130,6 +129,7 @@ public class LLMDslParser implements SemanticParser {
String correctorSql = semanticCorrectInfo.getSql();
parseInfo.getSqlInfo().setLogicSql(correctorSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
//set dataInfo
try {
@@ -143,8 +143,8 @@ public class LLMDslParser implements SemanticParser {
//set filter
try {
Map<String, SchemaElement> bizNameToElement = getBizNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(bizNameToElement, expressions);
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
@@ -173,20 +173,18 @@ public class LLMDslParser implements SemanticParser {
}
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
String bizName = expression.getFieldName();
SchemaElement schemaElement = bizNameToElement.get(bizName);
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
if (Objects.isNull(schemaElement)) {
continue;
}
String fieldName = schemaElement.getName();
dimensionFilter.setName(fieldName);
dimensionFilter.setBizName(bizName);
dimensionFilter.setName(schemaElement.getName());
dimensionFilter.setBizName(schemaElement.getBizName());
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
@@ -198,13 +196,8 @@ public class LLMDslParser implements SemanticParser {
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> {
List<String> nameList = TimeDimensionEnum.getNameList();
if (StringUtils.isEmpty(expression.getFieldName())) {
return false;
}
return nameList.contains(expression.getFieldName().toLowerCase());
}).collect(Collectors.toList());
.filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
}
@@ -354,7 +347,7 @@ public class LLMDslParser implements SemanticParser {
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
fieldNameList.add(DateUtils.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
@@ -391,7 +384,7 @@ public class LLMDslParser implements SemanticParser {
}
protected Map<String, SchemaElement> getBizNameToElement(Long modelId) {
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics();
@@ -401,7 +394,7 @@ public class LLMDslParser implements SemanticParser {
allElements.addAll(metrics);
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2));
.collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
}

View File

@@ -68,7 +68,7 @@ class LLMDslParserTest {
model.setId(2L);
parseInfo.setModel(model);
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ")
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 ")
.parseInfo(parseInfo)
.build();