[improvement][chat]llm parser corrector is simplified by sql distribution (#120)

This commit is contained in:
lexluo09
2023-09-21 21:57:06 +08:00
committed by GitHub
parent 5c3fd75ed4
commit 03a4719aed
19 changed files with 191 additions and 402 deletions

View File

@@ -1,27 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class DateFieldCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
}
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFields(preSql,
getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()));
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionAliasCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(replaceAlias);
}
}

View File

@@ -1,17 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -16,11 +16,39 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldNameCorrector extends BaseSemanticCorrector {
public class GlobalCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
replaceAlias(semanticCorrectInfo);
updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldNameByBizName(semanticCorrectInfo);
addAggregateToMetric(semanticCorrectInfo);
}
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
}
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(replaceAlias);
}
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) {
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId());
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName);
semanticCorrectInfo.setSql(sql);
}
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return;
@@ -45,5 +73,4 @@ public class FieldNameCorrector extends BaseSemanticCorrector {
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
semanticCorrectInfo.setSql(sql);
}
}
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
}
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
}
}

View File

@@ -1,48 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class QueryFilterAppend extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
}

View File

@@ -13,11 +13,12 @@ import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@Slf4j
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql);
if (Objects.nonNull(havingExpression)) {

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class TableNameCorrector extends BaseSemanticCorrector {
public class TableCorrector extends BaseSemanticCorrector {
public static final String TABLE_PREFIX = "t_";

View File

@@ -1,26 +1,92 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldValueCorrector extends BaseSemanticCorrector {
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
addDateIfNotExist(semanticCorrectInfo);
parserDateDiffFunction(semanticCorrectInfo);
addQueryFilter(semanticCorrectInfo);
updateFieldValueByTechName(semanticCorrectInfo);
}
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
}
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
semanticCorrectInfo.setSql(sql);
}
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate);
}
semanticCorrectInfo.setSql(sql);
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
@@ -39,7 +105,6 @@ public class FieldValueCorrector extends BaseSemanticCorrector {
return;
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
if (CollectionUtils.isEmpty(dimensions)) {
return new HashMap<>();

View File

@@ -408,27 +408,20 @@ public class LLMDslParser implements SemanticParser {
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
results.addAll(fieldNameList);
return new ArrayList<>(results);
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
Set<String> results = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>(results);
return new HashSet<>();
}
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
@@ -447,13 +440,29 @@ public class LLMDslParser implements SemanticParser {
})
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
.collect(Collectors.toSet());
results.addAll(fieldNameList);
return new ArrayList<>(results);
return fieldNameList;
}
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
Set<String> results = semanticSchema.getDimensions(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
return results;
}
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
return semanticSchema.getDimensions().stream()
.filter(entry -> modelId.equals(entry.getModel()))
return semanticSchema.getDimensions(modelId).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}