[improvement](chat) rule and llm supports replace metric (#408)

This commit is contained in:
mainmain
2023-11-20 17:04:14 +08:00
committed by GitHub
parent e15e44e4a2
commit 5b8fdbc6fd
7 changed files with 143 additions and 53 deletions

View File

@@ -12,6 +12,7 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -67,6 +68,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
correctS2SQL = SqlParserRemoveHelper.removeNumberCondition(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
@@ -77,8 +79,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
if (StringUtils.isNotBlank(currentDate)) {
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(),
currentDate);
correctS2SQL = SqlParserAddHelper.addWhere(
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);

View File

@@ -273,36 +273,21 @@ public class QueryServiceImpl implements QueryService {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
List<String> metrics = queryData.getMetrics().stream().map(o -> o.getName()).collect(Collectors.toList());
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
correctorSql = SqlParserAddHelper.addWhere(correctorSql, addWhereConditions);
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql);
fields = SqlParserSelectHelper.getAllFields(correctorSql);
}
if (CollectionUtils.isNotEmpty(fields) && !fields.containsAll(metrics)
&& CollectionUtils.isNotEmpty(queryData.getMetrics())) {
//replace metrics
log.info("llm begin replace metrics!");
replaceMetrics(parseInfo, metrics);
} else if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
String explainSql = semanticQuery.explain(user);
@@ -310,6 +295,10 @@ public class QueryServiceImpl implements QueryService {
parseInfo.getSqlInfo().setQuerySQL(explainSql);
}
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql
semanticQuery.initS2Sql(user);
QueryReq queryReq = new QueryReq();
@@ -324,6 +313,50 @@ public class QueryServiceImpl implements QueryService {
return queryResult;
}
public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
correctorSql = SqlParserAddHelper.addWhere(correctorSql, addWhereConditions);
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql);
return correctorSql;
}
private void replaceMetrics(SemanticParseInfo parseInfo, List<String> metrics) {
List<String> filteredMetrics = parseInfo.getMetrics().stream()
.map(o -> o.getName()).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("before replaceMetrics:{}", correctorSql);
correctorSql = SqlParserAddHelper.addFieldsToSelect(correctorSql, metrics);
correctorSql = SqlParserRemoveHelper.removeSelect(correctorSql, filteredMetrics);
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
}
@Override
public EntityInfo getEntityInfo(Long queryId, Integer parseId, User user) {
ChatParseDO chatParseDO = chatService.getParseInfo(queryId, parseId);
@@ -520,12 +553,27 @@ public class QueryServiceImpl implements QueryService {
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}
if (CollectionUtils.isNotEmpty(queryData.getMetricFilters())) {
parseInfo.setMetricFilters(queryData.getMetricFilters());
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
return parseInfo;
}
private void validFilter(Set<QueryFilter> filters) {
for (QueryFilter queryFilter : filters) {
if (Objects.isNull(queryFilter.getValue())) {
filters.remove(queryFilter);
}
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
filters.remove(queryFilter);
}
}
}
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
QueryResultWithSchemaResp queryResultWithSchemaResp = new QueryResultWithSchemaResp();