From 0f02e21eaa904fb61e76eb59b17947028959a12e Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 13 Nov 2023 15:36:27 +0800 Subject: [PATCH] (improvement)(chat) rename logicSql to correctS2SQL and the rule does not perform corrector operation (#373) --- .../chat/corrector/BaseSemanticCorrector.java | 14 +++++----- .../chat/corrector/GroupByCorrector.java | 12 ++++---- .../chat/corrector/HavingCorrector.java | 8 +++--- .../chat/corrector/SelectCorrector.java | 8 +++--- .../chat/corrector/WhereCorrector.java | 28 +++++++++---------- .../chat/query/BaseSemanticQuery.java | 14 ++++------ .../service/impl/ParserInfoServiceImpl.java | 17 ++++++----- .../chat/service/impl/QueryServiceImpl.java | 5 ++++ 8 files changed, 56 insertions(+), 50 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index 30f0ab9d8..fc6166769 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -71,10 +71,10 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return result; } - protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String logicSql) { - Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(logicSql)); - Set needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(logicSql)); - needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(logicSql)); + protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) { + Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL)); + Set needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL)); + needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL)); if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) { return; @@ -82,13 +82,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { needAddFields.removeAll(selectFields); needAddFields.remove(TimeDimensionEnum.DAY.getChName()); - String replaceFields = SqlParserAddHelper.addFieldsToSelect(logicSql, new ArrayList<>(needAddFields)); + String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields); } protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) { //add aggregate to all metric - String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); Long modelId = semanticParseInfo.getModel().getModel(); List metrics = getMetricElements(modelId); @@ -104,7 +104,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { if (CollectionUtils.isEmpty(metricToAggregate)) { return; } - String aggregateSql = SqlParserAddHelper.addAggregateToField(logicSql, metricToAggregate); + String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate); semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index bbd00d197..c7147d720 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -31,7 +31,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { //add dimension group by SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String logicSql = sqlInfo.getCorrectS2SQL(); + String correctS2SQL = sqlInfo.getCorrectS2SQL(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); //add alias field name Set dimensions = semanticSchema.getDimensions(modelId).stream() @@ -47,7 +47,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { ).collect(Collectors.toSet()); dimensions.add(TimeDimensionEnum.DAY.getChName()); - List selectFields = SqlParserSelectHelper.getSelectFields(logicSql); + List selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL); if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) { return; @@ -56,12 +56,12 @@ public class GroupByCorrector extends BaseSemanticCorrector { if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) { return; } - if (SqlParserSelectHelper.hasGroupBy(logicSql)) { - log.info("not add group by ,exist group by in logicSql:{}", logicSql); + if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) { + log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL); return; } - List aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql); + List aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL); Set groupByFields = selectFields.stream() .filter(field -> dimensions.contains(field)) .filter(field -> { @@ -71,7 +71,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { return true; }) .collect(Collectors.toSet()); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(logicSql, groupByFields)); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields)); addAggregate(semanticParseInfo); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java index 3de527134..7a1f2bba2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -44,13 +44,13 @@ public class HavingCorrector extends BaseSemanticCorrector { } private void addHavingToSelect(SemanticParseInfo semanticParseInfo) { - String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); - if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) { + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) { return; } - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql); + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(correctS2SQL); if (Objects.nonNull(havingExpression)) { - String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression); + String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpression); semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql); } return; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java index 3135ac420..3036fbb82 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java @@ -12,15 +12,15 @@ public class SelectCorrector extends BaseSemanticCorrector { @Override public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { - String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); - List aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql); - List selectFields = SqlParserSelectHelper.getSelectFields(logicSql); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + List aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL); + List selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL); // If the number of aggregated fields is equal to the number of queried fields, do not add fields to select. if (!CollectionUtils.isEmpty(aggregateFields) && !CollectionUtils.isEmpty(selectFields) && aggregateFields.size() == selectFields.size()) { return; } - addFieldsToSelect(semanticParseInfo, logicSql); + addFieldsToSelect(semanticParseInfo, correctS2SQL); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index ed5a7f037..6fc2f3099 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -46,38 +46,38 @@ public class WhereCorrector extends BaseSemanticCorrector { private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { String queryFilter = getQueryFilter(queryReq.getQueryFilters()); - String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); if (StringUtils.isNotEmpty(queryFilter)) { - log.info("add queryFilter to logicSql :{}", queryFilter); + log.info("add queryFilter to correctS2SQL :{}", queryFilter); Expression expression = null; try { expression = CCJSqlParserUtil.parseCondExpression(queryFilter); } catch (JSQLParserException e) { log.error("parseCondExpression", e); } - logicSql = SqlParserAddHelper.addWhere(logicSql, expression); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql); + correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); } } private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) { - String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); - logicSql = SqlParserReplaceHelper.replaceFunction(logicSql); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); } private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) { - String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); - List whereFields = SqlParserSelectHelper.getWhereFields(logicSql); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + List whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL); if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) { String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId()); if (StringUtils.isNotBlank(currentDate)) { - logicSql = SqlParserAddHelper.addParenthesisToWhere(logicSql); - logicSql = SqlParserAddHelper.addWhere(logicSql, TimeDimensionEnum.DAY.getChName(), currentDate); + correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL); + correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate); } } - semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); } private String getQueryFilter(QueryFilters queryFilters) { @@ -106,9 +106,9 @@ public class WhereCorrector extends BaseSemanticCorrector { } Map> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); - String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), + String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), aliasAndBizNameToTechName); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); return; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java index da92e43fd..293056504 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java @@ -41,26 +41,24 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { public String explain(User user) { ExplainSqlReq explainSqlReq = null; try { - ExplainResp explain = null; SqlInfo sqlInfo = parseInfo.getSqlInfo(); + if (StringUtils.isNotBlank(sqlInfo.getCorrectS2SQL())) { //sql - QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(sqlInfo.getCorrectS2SQL(), - parseInfo.getModelId()); explainSqlReq = ExplainSqlReq.builder() .queryTypeEnum(QueryTypeEnum.SQL) - .queryReq(queryS2SQLReq) + .queryReq(QueryReqBuilder.buildS2SQLReq( + sqlInfo.getCorrectS2SQL(), parseInfo.getModelId() + )) .build(); - explain = semanticInterpreter.explain(explainSqlReq, user); } else { //struct - QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo); explainSqlReq = ExplainSqlReq.builder() .queryTypeEnum(QueryTypeEnum.STRUCT) - .queryReq(queryStructReq) + .queryReq(QueryReqBuilder.buildStructReq(parseInfo)) .build(); - explain = semanticInterpreter.explain(explainSqlReq, user); } + ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user); if (Objects.nonNull(explain)) { return explain.getSql(); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java index b7ec8b744..d1f328b23 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java @@ -38,12 +38,16 @@ public class ParserInfoServiceImpl implements ParseInfoService { public void updateParseInfo(SemanticParseInfo parseInfo) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); - String logicSql = sqlInfo.getCorrectS2SQL(); - if (StringUtils.isBlank(logicSql)) { + String correctS2SQL = sqlInfo.getCorrectS2SQL(); + if (StringUtils.isBlank(correctS2SQL)) { + return; + } + // if S2SQL equals correctS2SQL, than not update the parseInfo. + if (correctS2SQL.equals(sqlInfo.getS2SQL())) { return; } - List expressions = SqlParserSelectHelper.getFilterExpression(logicSql); + List expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL); //set dataInfo try { if (!CollectionUtils.isEmpty(expressions)) { @@ -70,10 +74,9 @@ public class ParserInfoServiceImpl implements ParseInfoService { if (Objects.isNull(semanticSchema)) { return; } - //cannot use metrics in sql to override parse info - //List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); - //Set metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics()); - //parseInfo.setMetrics(metrics); + List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); + Set metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics()); + parseInfo.setMetrics(metrics); if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) { parseInfo.setNativeQuery(false); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 5d681af5c..2801623e0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -29,6 +29,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryRanker; import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery; +import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; import com.tencent.supersonic.chat.service.ChatService; @@ -148,6 +149,10 @@ public class QueryServiceImpl implements QueryService { if (CollectionUtils.isNotEmpty(candidateQueries)) { for (SemanticQuery semanticQuery : candidateQueries) { semanticQuery.initS2Sql(queryReq.getUser()); + // rule + if (semanticQuery instanceof RuleSemanticQuery) { + continue; + } semanticCorrectors.stream().forEach(correction -> { correction.correct(queryReq, semanticQuery.getParseInfo()); });