From 16c3de44e4a7ffe5306aae4a2e53841bd4059c56 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:03:05 +0800 Subject: [PATCH] (improvement)(chat) Overall code optimization for the corrector (#345) --- .../chat/api/component/SemanticCorrector.java | 6 +- .../chat/api/component/SemanticQuery.java | 7 +- .../chat/config/OptimizationConfig.java | 4 +- .../chat/corrector/BaseSemanticCorrector.java | 42 ++- .../chat/corrector/CorrectorService.java | 13 - .../chat/corrector/CorrectorServiceImpl.java | 95 ------ .../chat/corrector/GlobalAfterCorrector.java | 16 +- .../chat/corrector/GlobalBeforeCorrector.java | 59 ++-- .../chat/corrector/GroupByCorrector.java | 36 ++- .../chat/corrector/HavingCorrector.java | 14 +- .../chat/corrector/SelectCorrector.java | 14 +- .../chat/corrector/WhereCorrector.java | 58 ++-- .../supersonic/chat/mapper/BaseMapper.java | 4 +- .../chat/mapper/BaseMatchStrategy.java | 6 +- .../parser/llm/s2ql/LLMResponseService.java | 198 +----------- .../chat/query/BaseSemanticQuery.java | 291 ++++++++++++++++++ .../llm/interpret/MetricInterpretQuery.java | 24 +- .../chat/query/llm/s2ql/S2QLQuery.java | 20 +- .../query/plugin/PluginSemanticQuery.java | 21 +- .../chat/query/rule/RuleSemanticQuery.java | 45 +-- .../parse/SqlInfoParseResponder.java | 8 +- .../chat/service/SemanticService.java | 4 - .../chat/service/impl/QueryServiceImpl.java | 51 +-- .../chat/utils/DictQueryHelper.java | 11 +- .../api/query/request/QueryStructReq.java | 2 - 25 files changed, 507 insertions(+), 542 deletions(-) delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorService.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorServiceImpl.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticCorrector.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticCorrector.java index 7eda4759d..1a4441286 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticCorrector.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticCorrector.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.chat.api.component; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import net.sf.jsqlparser.JSQLParserException; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; /** * A semantic corrector checks validity of extracted semantic information and @@ -9,5 +9,5 @@ import net.sf.jsqlparser.JSQLParserException; */ public interface SemanticCorrector { - void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException; + void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo); } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java index 7cf3cf7cb..8622b0cda 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java @@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.api.component; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; -import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import org.apache.calcite.sql.parser.SqlParseException; /** @@ -15,9 +14,13 @@ public interface SemanticQuery { QueryResult execute(User user) throws SqlParseException; - SqlInfo explain(User user); + void initS2Sql(User user); + + String explain(User user); SemanticParseInfo getParseInfo(); + void updateParseInfo(); + void setParseInfo(SemanticParseInfo parseInfo); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java index 772870053..e5bcc0023 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java @@ -42,7 +42,7 @@ public class OptimizationConfig { @Value("${user.s2ql.switch:false}") private boolean useS2qlSwitch; - @Value("${embedding.mapper.word.min:3}") + @Value("${embedding.mapper.word.min:4}") private int embeddingMapperWordMin; @Value("${embedding.mapper.word.max:5}") @@ -57,6 +57,6 @@ public class OptimizationConfig { @Value("${embedding.mapper.round.number:10}") private int embeddingMapperRoundNumber; - @Value("${embedding.mapper.distance.threshold:0.52}") + @Value("${embedding.mapper.distance.threshold:0.58}") private Double embeddingMapperDistanceThreshold; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index e42fb8676..6b72aeb20 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.component.SemanticCorrector; import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; @@ -18,15 +19,27 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @Slf4j public abstract class BaseSemanticCorrector implements SemanticCorrector { - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql()); + public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { + try { + if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getLogicSql())) { + return; + } + work(queryReq, semanticParseInfo); + log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo()); + } catch (Exception e) { + log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e); + } } + + public abstract void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo); + protected Map getFieldNameMap(Long modelId) { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); @@ -58,10 +71,10 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return result; } - protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) { - Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql)); - Set needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(sql)); - needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); + protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String logicSql) { + Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(logicSql)); + Set needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(logicSql)); + needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(logicSql)); if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) { return; @@ -69,14 +82,14 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { needAddFields.removeAll(selectFields); needAddFields.remove(TimeDimensionEnum.DAY.getChName()); - String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(needAddFields)); - semanticCorrectInfo.setSql(replaceFields); + String replaceFields = SqlParserAddHelper.addFieldsToSelect(logicSql, new ArrayList<>(needAddFields)); + semanticParseInfo.getSqlInfo().setLogicSql(replaceFields); } - protected void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { + protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) { //add aggregate to all metric - String sql = semanticCorrectInfo.getSql(); - Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); + String logicSql = semanticParseInfo.getSqlInfo().getLogicSql(); + Long modelId = semanticParseInfo.getModel().getModel(); List metrics = getMetricElements(modelId); @@ -91,9 +104,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { if (CollectionUtils.isEmpty(metricToAggregate)) { return; } - - String aggregateSql = SqlParserAddHelper.addAggregateToField(sql, metricToAggregate); - semanticCorrectInfo.setSql(aggregateSql); + String aggregateSql = SqlParserAddHelper.addAggregateToField(logicSql, metricToAggregate); + semanticParseInfo.getSqlInfo().setLogicSql(aggregateSql); } protected List getMetricElements(Long modelId) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorService.java deleted file mode 100644 index ae46479a9..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorService.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; - -public interface CorrectorService { - - SemanticCorrectInfo correctorSql(QueryFilters queryFilters, SemanticParseInfo parseInfo, String sql); - - void addS2QLAndLoginSql(QueryStructReq queryStructReq, SemanticParseInfo parseInfo); -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorServiceImpl.java deleted file mode 100644 index 459242709..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/CorrectorServiceImpl.java +++ /dev/null @@ -1,95 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.component.SemanticCorrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.chat.utils.ComponentFactory; -import com.tencent.supersonic.common.pojo.Aggregator; -import com.tencent.supersonic.common.pojo.Filter; -import com.tencent.supersonic.common.pojo.Order; -import com.tencent.supersonic.knowledge.service.SchemaService; -import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; -import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections4.CollectionUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; - -@Slf4j -@Service -public class CorrectorServiceImpl implements CorrectorService { - - @Autowired - private SchemaService schemaService; - - public SemanticCorrectInfo correctorSql(QueryFilters queryFilters, SemanticParseInfo parseInfo, String sql) { - - SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder() - .queryFilters(queryFilters).sql(sql) - .parseInfo(parseInfo).build(); - - List corrections = ComponentFactory.getSqlCorrections(); - - corrections.forEach(correction -> { - try { - correction.correct(correctInfo); - log.info("sqlCorrection:{} sql:{}", correction.getClass().getSimpleName(), correctInfo.getSql()); - } catch (Exception e) { - log.error(String.format("correct error,correctInfo:%s", correctInfo), e); - } - }); - return correctInfo; - } - - - public void addS2QLAndLoginSql(QueryStructReq queryStructReq, SemanticParseInfo parseInfo) { - convertBizNameToName(queryStructReq, parseInfo); - QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq); - parseInfo.getSqlInfo().setS2QL(queryS2QLReq.getSql()); - queryStructReq.setS2QL(queryS2QLReq.getSql()); - - SemanticCorrectInfo semanticCorrectInfo = correctorSql(new QueryFilters(), parseInfo, - queryS2QLReq.getSql()); - parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql()); - - queryStructReq.setLogicSql(semanticCorrectInfo.getSql()); - } - - - private void convertBizNameToName(QueryStructReq queryStructReq, SemanticParseInfo parseInfo) { - Map bizNameToName = schemaService.getSemanticSchema() - .getBizNameToName(queryStructReq.getModelId()); - List orders = queryStructReq.getOrders(); - if (CollectionUtils.isNotEmpty(orders)) { - for (Order order : orders) { - order.setColumn(bizNameToName.get(order.getColumn())); - } - } - List aggregators = queryStructReq.getAggregators(); - if (CollectionUtils.isNotEmpty(aggregators)) { - for (Aggregator aggregator : aggregators) { - aggregator.setColumn(bizNameToName.get(aggregator.getColumn())); - } - } - List groups = queryStructReq.getGroups(); - if (CollectionUtils.isNotEmpty(groups)) { - groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList()); - queryStructReq.setGroups(groups); - } - List dimensionFilters = queryStructReq.getDimensionFilters(); - if (CollectionUtils.isNotEmpty(dimensionFilters)) { - dimensionFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); - } - List metricFilters = queryStructReq.getMetricFilters(); - if (CollectionUtils.isNotEmpty(dimensionFilters)) { - metricFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); - } - - queryStructReq.setModelName(parseInfo.getModelName()); - } - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java index 5d6d6c51d..56fa9b19e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.corrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; @@ -12,17 +13,16 @@ import net.sf.jsqlparser.expression.Expression; public class GlobalAfterCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { + public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { - super.correct(semanticCorrectInfo); - String sql = semanticCorrectInfo.getSql(); - if (!SqlParserSelectFunctionHelper.hasAggregateFunction(sql)) { + String logicSql = semanticParseInfo.getSqlInfo().getLogicSql(); + if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) { return; } - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql); if (Objects.nonNull(havingExpression)) { - String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); - semanticCorrectInfo.setSql(replaceSql); + String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression); + semanticParseInfo.getSqlInfo().setLogicSql(replaceSql); } return; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java index ed56240f3..ba0331612 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java @@ -1,6 +1,8 @@ package com.tencent.supersonic.chat.corrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult; import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue; @@ -19,35 +21,32 @@ import org.springframework.util.CollectionUtils; public class GlobalBeforeCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { + public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { - super.correct(semanticCorrectInfo); + replaceAlias(semanticParseInfo); - replaceAlias(semanticCorrectInfo); + updateFieldNameByLinkingValue(semanticParseInfo); - updateFieldNameByLinkingValue(semanticCorrectInfo); + updateFieldValueByLinkingValue(semanticParseInfo); - updateFieldValueByLinkingValue(semanticCorrectInfo); - - correctFieldName(semanticCorrectInfo); + correctFieldName(semanticParseInfo); } - private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { - String replaceAlias = SqlParserReplaceHelper.replaceAlias(semanticCorrectInfo.getSql()); - semanticCorrectInfo.setSql(replaceAlias); + private void replaceAlias(SemanticParseInfo semanticParseInfo) { + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getLogicSql()); + sqlInfo.setLogicSql(replaceAlias); } - private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) { - - Map fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId()); - - String sql = SqlParserReplaceHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap); - - semanticCorrectInfo.setSql(sql); + private void correctFieldName(SemanticParseInfo semanticParseInfo) { + Map fieldNameMap = getFieldNameMap(semanticParseInfo.getModelId()); + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getLogicSql(), fieldNameMap); + sqlInfo.setLogicSql(sql); } - private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) { - List linking = getLinkingValues(semanticCorrectInfo); + private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) { + List linking = getLinkingValues(semanticParseInfo); if (CollectionUtils.isEmpty(linking)) { return; } @@ -56,13 +55,14 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { Collectors.groupingBy(ElementValue::getFieldValue, Collectors.mapping(ElementValue::getFieldName, Collectors.toSet()))); - String sql = SqlParserReplaceHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(), - fieldValueToFieldNames); - semanticCorrectInfo.setSql(sql); + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + + String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getLogicSql(), fieldValueToFieldNames); + sqlInfo.setLogicSql(sql); } - private List getLinkingValues(SemanticCorrectInfo semanticCorrectInfo) { - Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT); + private List getLinkingValues(SemanticParseInfo semanticParseInfo) { + Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT); if (Objects.isNull(context)) { return null; } @@ -76,8 +76,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { } - private void updateFieldValueByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) { - List linking = getLinkingValues(semanticCorrectInfo); + private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) { + List linking = getLinkingValues(semanticParseInfo); if (CollectionUtils.isEmpty(linking)) { return; } @@ -90,7 +90,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { (existingValue, newValue) -> newValue) ))); - String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false); - semanticCorrectInfo.setSql(sql); + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getLogicSql(), filedNameToValueMap, false); + sqlInfo.setLogicSql(sql); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index 3f0c29872..72ebb17c0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -1,7 +1,9 @@ package com.tencent.supersonic.chat.corrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; @@ -18,19 +20,18 @@ import org.springframework.util.CollectionUtils; public class GroupByCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { + public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { - super.correct(semanticCorrectInfo); - - addGroupByFields(semanticCorrectInfo); + addGroupByFields(semanticParseInfo); } - private void addGroupByFields(SemanticCorrectInfo semanticCorrectInfo) { - Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); + private void addGroupByFields(SemanticParseInfo semanticParseInfo) { + Long modelId = semanticParseInfo.getModel().getModel(); //add dimension group by - String sql = semanticCorrectInfo.getSql(); + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String logicSql = sqlInfo.getLogicSql(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); //add alias field name Set dimensions = semanticSchema.getDimensions(modelId).stream() @@ -46,7 +47,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { ).collect(Collectors.toSet()); dimensions.add(TimeDimensionEnum.DAY.getChName()); - List selectFields = SqlParserSelectHelper.getSelectFields(sql); + List selectFields = SqlParserSelectHelper.getSelectFields(logicSql); if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) { return; @@ -55,12 +56,12 @@ public class GroupByCorrector extends BaseSemanticCorrector { if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) { return; } - if (SqlParserSelectHelper.hasGroupBy(sql)) { - log.info("not add group by ,exist group by in sql:{}", sql); + if (SqlParserSelectHelper.hasGroupBy(logicSql)) { + log.info("not add group by ,exist group by in logicSql:{}", logicSql); return; } - List aggregateFields = SqlParserSelectHelper.getAggregateFields(sql); + List aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql); Set groupByFields = selectFields.stream() .filter(field -> dimensions.contains(field)) .filter(field -> { @@ -70,16 +71,17 @@ public class GroupByCorrector extends BaseSemanticCorrector { return true; }) .collect(Collectors.toSet()); - semanticCorrectInfo.setSql(SqlParserAddHelper.addGroupBy(sql, groupByFields)); + semanticParseInfo.getSqlInfo().setLogicSql(SqlParserAddHelper.addGroupBy(logicSql, groupByFields)); - addAggregate(semanticCorrectInfo); + addAggregate(semanticParseInfo); } - private void addAggregate(SemanticCorrectInfo semanticCorrectInfo) { - List sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql()); + private void addAggregate(SemanticParseInfo semanticParseInfo) { + List sqlGroupByFields = SqlParserSelectHelper.getGroupByFields( + semanticParseInfo.getSqlInfo().getLogicSql()); if (CollectionUtils.isEmpty(sqlGroupByFields)) { return; } - addAggregateToMetric(semanticCorrectInfo); + addAggregateToMetric(semanticParseInfo); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java index badb1d328..0eee5d670 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -1,7 +1,8 @@ package com.tencent.supersonic.chat.corrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.knowledge.service.SchemaService; @@ -14,13 +15,10 @@ import org.springframework.util.CollectionUtils; public class HavingCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - - super.correct(semanticCorrectInfo); + public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { //add aggregate to all metric - semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql()); - Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); + Long modelId = semanticParseInfo.getModel().getModel(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); @@ -30,8 +28,8 @@ public class HavingCorrector extends BaseSemanticCorrector { if (CollectionUtils.isEmpty(metrics)) { return; } - String havingSql = SqlParserAddHelper.addHaving(semanticCorrectInfo.getSql(), metrics); - semanticCorrectInfo.setSql(havingSql); + String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getLogicSql(), metrics); + semanticParseInfo.getSqlInfo().setLogicSql(havingSql); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java index e11f166fd..7394dc023 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.corrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import java.util.List; import lombok.extern.slf4j.Slf4j; @@ -10,17 +11,16 @@ import org.springframework.util.CollectionUtils; public class SelectCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - super.correct(semanticCorrectInfo); - String sql = semanticCorrectInfo.getSql(); - List aggregateFields = SqlParserSelectHelper.getAggregateFields(sql); - List selectFields = SqlParserSelectHelper.getSelectFields(sql); + public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { + String logicSql = semanticParseInfo.getSqlInfo().getLogicSql(); + List aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql); + List selectFields = SqlParserSelectHelper.getSelectFields(logicSql); // If the number of aggregated fields is equal to the number of queried fields, do not add fields to select. if (!CollectionUtils.isEmpty(aggregateFields) && !CollectionUtils.isEmpty(selectFields) && aggregateFields.size() == selectFields.size()) { return; } - addFieldsToSelect(semanticCorrectInfo, sql); + addFieldsToSelect(semanticParseInfo, logicSql); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index 88fe7283c..987474a17 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaValueMap; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.parser.llm.s2ql.S2QLDateHelper; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; @@ -31,54 +32,52 @@ import org.springframework.util.CollectionUtils; public class WhereCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { + public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { - super.correct(semanticCorrectInfo); + addDateIfNotExist(semanticParseInfo); - addDateIfNotExist(semanticCorrectInfo); + parserDateDiffFunction(semanticParseInfo); - parserDateDiffFunction(semanticCorrectInfo); + addQueryFilter(queryReq, semanticParseInfo); - addQueryFilter(semanticCorrectInfo); - - updateFieldValueByTechName(semanticCorrectInfo); + updateFieldValueByTechName(semanticParseInfo); } - private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) { - String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters()); + private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { + String queryFilter = getQueryFilter(queryReq.getQueryFilters()); - String preSql = semanticCorrectInfo.getSql(); + String logicSql = semanticParseInfo.getSqlInfo().getLogicSql(); if (StringUtils.isNotEmpty(queryFilter)) { - log.info("add queryFilter to preSql :{}", queryFilter); + log.info("add queryFilter to logicSql :{}", queryFilter); Expression expression = null; try { expression = CCJSqlParserUtil.parseCondExpression(queryFilter); } catch (JSQLParserException e) { log.error("parseCondExpression", e); } - String sql = SqlParserAddHelper.addWhere(preSql, expression); - semanticCorrectInfo.setSql(sql); + logicSql = SqlParserAddHelper.addWhere(logicSql, expression); + semanticParseInfo.getSqlInfo().setLogicSql(logicSql); } } - private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) { - String sql = semanticCorrectInfo.getSql(); - sql = SqlParserReplaceHelper.replaceFunction(sql); - semanticCorrectInfo.setSql(sql); + private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) { + String logicSql = semanticParseInfo.getSqlInfo().getLogicSql(); + logicSql = SqlParserReplaceHelper.replaceFunction(logicSql); + semanticParseInfo.getSqlInfo().setLogicSql(logicSql); } - private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) { - String sql = semanticCorrectInfo.getSql(); - List whereFields = SqlParserSelectHelper.getWhereFields(sql); + private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) { + String logicSql = semanticParseInfo.getSqlInfo().getLogicSql(); + List whereFields = SqlParserSelectHelper.getWhereFields(logicSql); if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) { - String currentDate = S2QLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); + String currentDate = S2QLDateHelper.getReferenceDate(semanticParseInfo.getModelId()); if (StringUtils.isNotBlank(currentDate)) { - sql = SqlParserAddHelper.addParenthesisToWhere(sql); - sql = SqlParserAddHelper.addWhere(sql, TimeDimensionEnum.DAY.getChName(), currentDate); + logicSql = SqlParserAddHelper.addParenthesisToWhere(logicSql); + logicSql = SqlParserAddHelper.addWhere(logicSql, TimeDimensionEnum.DAY.getChName(), currentDate); } } - semanticCorrectInfo.setSql(sql); + semanticParseInfo.getSqlInfo().setLogicSql(logicSql); } private String getQueryFilter(QueryFilters queryFilters) { @@ -95,9 +94,9 @@ public class WhereCorrector extends BaseSemanticCorrector { .collect(Collectors.joining(Constants.AND_UPPER)); } - private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) { + private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId(); + Long modelId = semanticParseInfo.getModel().getId(); List dimensions = semanticSchema.getDimensions().stream() .filter(schemaElement -> modelId.equals(schemaElement.getModel())) .collect(Collectors.toList()); @@ -107,8 +106,9 @@ public class WhereCorrector extends BaseSemanticCorrector { } Map> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); - String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName); - semanticCorrectInfo.setSql(sql); + String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getLogicSql(), + aliasAndBizNameToTechName); + semanticParseInfo.getSqlInfo().setLogicSql(logicSql); return; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java index e09e325e0..c50becc54 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java @@ -29,7 +29,7 @@ public abstract class BaseMapper implements SchemaMapper { String simpleName = this.getClass().getSimpleName(); long startTime = System.currentTimeMillis(); - log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo()); + log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches()); try { work(queryContext); @@ -38,7 +38,7 @@ public abstract class BaseMapper implements SchemaMapper { } long cost = System.currentTimeMillis() - startTime; - log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo()); + log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches()); } public abstract void work(QueryContext queryContext); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java index 6be372cf4..241a86238 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java @@ -20,7 +20,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; /** - * match strategy implement + * Base Match Strategy */ @Service @Slf4j @@ -36,7 +36,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { return null; } - log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds); + log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds); List detects = detect(queryContext, terms, detectModelIds); Map> result = new HashMap<>(); @@ -143,7 +143,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { return; } for (Term term : terms) { - log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency()); + log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency()); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMResponseService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMResponseService.java index 6174775ae..735d826e8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMResponseService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMResponseService.java @@ -1,194 +1,26 @@ package com.tencent.supersonic.chat.parser.llm.s2ql; -import com.google.common.collect.Lists; import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.chat.corrector.CorrectorService; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.pojo.DateConf; -import com.tencent.supersonic.common.pojo.DateConf.DateMode; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; -import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; @Slf4j @Service public class LLMResponseService { - - @Autowired - private CorrectorService correctorService; - - public void addParseInfo(QueryContext queryCtx, ParseResult parseResult, String sql, Double weight) { - - SemanticParseInfo parseInfo = getParseInfo(queryCtx, parseResult, weight); - - QueryFilters queryFilters = queryCtx.getRequest().getQueryFilters(); - SemanticCorrectInfo semanticCorrectInfo = correctorService.correctorSql(queryFilters, parseInfo, sql); - - parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql()); - - updateParseInfo(semanticCorrectInfo, parseResult.getModelId(), parseInfo); - } - - private Set getElements(Long modelId, List allFields, List elements) { - return elements.stream() - .filter(schemaElement -> modelId.equals(schemaElement.getModel()) - && allFields.contains(schemaElement.getName()) - ).collect(Collectors.toSet()); - } - - private List getFieldsExceptDate(List allFields) { - if (CollectionUtils.isEmpty(allFields)) { - return new ArrayList<>(); - } - return allFields.stream() - .filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry)) - .collect(Collectors.toList()); - } - - public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) { - - String correctorSql = semanticCorrectInfo.getSql(); - parseInfo.getSqlInfo().setLogicSql(correctorSql); - - List expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); - //set dataInfo - try { - if (!CollectionUtils.isEmpty(expressions)) { - DateConf dateInfo = getDateInfo(expressions); - parseInfo.setDateInfo(dateInfo); - } - } catch (Exception e) { - log.error("set dateInfo error :", e); - } - - //set filter - try { - Map fieldNameToElement = getNameToElement(modelId); - List result = getDimensionFilter(fieldNameToElement, expressions); - parseInfo.getDimensionFilters().addAll(result); - } catch (Exception e) { - log.error("set dimensionFilter error :", e); - } - - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - - if (Objects.isNull(semanticSchema)) { - return; - } - List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql())); - - Set metrics = getElements(modelId, allFields, semanticSchema.getMetrics()); - parseInfo.setMetrics(metrics); - - if (SqlParserSelectFunctionHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) { - parseInfo.setNativeQuery(false); - List groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql()); - List groupByDimensions = getFieldsExceptDate(groupByFields); - parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions())); - } else { - parseInfo.setNativeQuery(true); - List selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql()); - List selectDimensions = getFieldsExceptDate(selectFields); - parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions())); - } - } - - private List getDimensionFilter(Map fieldNameToElement, - List filterExpressions) { - List result = Lists.newArrayList(); - for (FilterExpression expression : filterExpressions) { - QueryFilter dimensionFilter = new QueryFilter(); - dimensionFilter.setValue(expression.getFieldValue()); - SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); - if (Objects.isNull(schemaElement)) { - continue; - } - dimensionFilter.setName(schemaElement.getName()); - dimensionFilter.setBizName(schemaElement.getBizName()); - dimensionFilter.setElementID(schemaElement.getId()); - - FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); - dimensionFilter.setOperator(operatorEnum); - dimensionFilter.setFunction(expression.getFunction()); - result.add(dimensionFilter); - } - return result; - } - - private DateConf getDateInfo(List filterExpressions) { - List dateExpressions = filterExpressions.stream() - .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName())) - .collect(Collectors.toList()); - if (CollectionUtils.isEmpty(dateExpressions)) { - return new DateConf(); - } - DateConf dateInfo = new DateConf(); - dateInfo.setDateMode(DateMode.BETWEEN); - FilterExpression firstExpression = dateExpressions.get(0); - - FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator()); - if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) { - dateInfo.setStartDate(firstExpression.getFieldValue().toString()); - dateInfo.setEndDate(firstExpression.getFieldValue().toString()); - dateInfo.setDateMode(DateMode.BETWEEN); - return dateInfo; - } - if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN, - FilterOperatorEnum.GREATER_THAN_EQUALS)) { - dateInfo.setStartDate(firstExpression.getFieldValue().toString()); - if (hasSecondDate(dateExpressions)) { - dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString()); - } - } - if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN, - FilterOperatorEnum.MINOR_THAN_EQUALS)) { - dateInfo.setEndDate(firstExpression.getFieldValue().toString()); - if (hasSecondDate(dateExpressions)) { - dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString()); - } - } - return dateInfo; - } - - private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, - FilterOperatorEnum... operatorEnums) { - return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); - } - - private boolean hasSecondDate(List dateExpressions) { - return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); - } - - - private SemanticParseInfo getParseInfo(QueryContext queryCtx, ParseResult parseResult, Double weight) { + public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2ql, Double weight) { if (Objects.isNull(weight)) { weight = 0D; } @@ -206,7 +38,7 @@ public class LLMResponseService { parseInfo.setProperties(properties); parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); parseInfo.setQueryMode(semanticQuery.getQueryMode()); - parseInfo.getSqlInfo().setS2QL(parseResult.getLlmResp().getSqlOutput()); + parseInfo.getSqlInfo().setS2QL(s2ql); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); Map modelIdToName = semanticSchema.getModelIdToName(); @@ -219,30 +51,4 @@ public class LLMResponseService { queryCtx.getCandidateQueries().add(semanticQuery); return parseInfo; } - - protected Map getNameToElement(Long modelId) { - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - List dimensions = semanticSchema.getDimensions(); - List metrics = semanticSchema.getMetrics(); - - List allElements = Lists.newArrayList(); - allElements.addAll(dimensions); - allElements.addAll(metrics); - //support alias - return allElements.stream() - .filter(schemaElement -> schemaElement.getModel().equals(modelId)) - .flatMap(schemaElement -> { - Set> result = new HashSet<>(); - result.add(Pair.of(schemaElement.getName(), schemaElement)); - List aliasList = schemaElement.getAlias(); - if (!CollectionUtils.isEmpty(aliasList)) { - for (String alias : aliasList) { - result.add(Pair.of(alias, schemaElement)); - } - } - return result.stream(); - }) - .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); - } - } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java new file mode 100644 index 000000000..3b4a8b98e --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java @@ -0,0 +1,291 @@ + +package com.tencent.supersonic.chat.query; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.api.component.SemanticInterpreter; +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; +import com.tencent.supersonic.chat.utils.ComponentFactory; +import com.tencent.supersonic.chat.utils.QueryReqBuilder; +import com.tencent.supersonic.common.pojo.Aggregator; +import com.tencent.supersonic.common.pojo.DateConf; +import com.tencent.supersonic.common.pojo.DateConf.DateMode; +import com.tencent.supersonic.common.pojo.Filter; +import com.tencent.supersonic.common.pojo.Order; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; +import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; +import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.ToString; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; + +@Slf4j +@ToString +public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { + + protected SemanticParseInfo parseInfo = new SemanticParseInfo(); + + protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); + + @Override + public String explain(User user) { + ExplainSqlReq explainSqlReq = null; + SqlInfo sqlInfo = parseInfo.getSqlInfo(); + try { + QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId()); + explainSqlReq = ExplainSqlReq.builder() + .queryTypeEnum(QueryTypeEnum.SQL) + .queryReq(queryS2QLReq) + .build(); + ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user); + if (Objects.nonNull(explain)) { + return explain.getSql(); + } + return explain.getSql(); + } catch (Exception e) { + log.error("explain error explainSqlReq:{}", explainSqlReq, e); + } + return null; + } + + @Override + public SemanticParseInfo getParseInfo() { + return parseInfo; + } + + @Override + public void setParseInfo(SemanticParseInfo parseInfo) { + this.parseInfo = parseInfo; + } + + protected QueryStructReq convertQueryStruct() { + return QueryReqBuilder.buildStructReq(parseInfo); + } + + public void updateParseInfo() { + SqlInfo sqlInfo = parseInfo.getSqlInfo(); + String logicSql = sqlInfo.getLogicSql(); + if (StringUtils.isBlank(logicSql)) { + return; + } + + List expressions = SqlParserSelectHelper.getFilterExpression(logicSql); + //set dataInfo + try { + if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) { + DateConf dateInfo = getDateInfo(expressions); + parseInfo.setDateInfo(dateInfo); + } + } catch (Exception e) { + log.error("set dateInfo error :", e); + } + + //set filter + try { + Map fieldNameToElement = getNameToElement(parseInfo.getModelId()); + List result = getDimensionFilter(fieldNameToElement, expressions); + parseInfo.getDimensionFilters().addAll(result); + } catch (Exception e) { + log.error("set dimensionFilter error :", e); + } + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + if (Objects.isNull(semanticSchema)) { + return; + } + List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql())); + + Set metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics()); + parseInfo.setMetrics(metrics); + + if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) { + parseInfo.setNativeQuery(false); + List groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql()); + List groupByDimensions = getFieldsExceptDate(groupByFields); + parseInfo.setDimensions( + getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions())); + } else { + parseInfo.setNativeQuery(true); + List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql()); + List selectDimensions = getFieldsExceptDate(selectFields); + parseInfo.setDimensions( + getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions())); + } + } + + + private Set getElements(Long modelId, List allFields, List elements) { + return elements.stream() + .filter(schemaElement -> modelId.equals(schemaElement.getModel()) + && allFields.contains(schemaElement.getName()) + ).collect(Collectors.toSet()); + } + + private List getFieldsExceptDate(List allFields) { + if (org.springframework.util.CollectionUtils.isEmpty(allFields)) { + return new ArrayList<>(); + } + return allFields.stream() + .filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry)) + .collect(Collectors.toList()); + } + + + private List getDimensionFilter(Map fieldNameToElement, + List filterExpressions) { + List result = Lists.newArrayList(); + for (FilterExpression expression : filterExpressions) { + QueryFilter dimensionFilter = new QueryFilter(); + dimensionFilter.setValue(expression.getFieldValue()); + SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); + if (Objects.isNull(schemaElement)) { + continue; + } + dimensionFilter.setName(schemaElement.getName()); + dimensionFilter.setBizName(schemaElement.getBizName()); + dimensionFilter.setElementID(schemaElement.getId()); + + FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); + dimensionFilter.setOperator(operatorEnum); + dimensionFilter.setFunction(expression.getFunction()); + result.add(dimensionFilter); + } + return result; + } + + private DateConf getDateInfo(List filterExpressions) { + List dateExpressions = filterExpressions.stream() + .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName())) + .collect(Collectors.toList()); + if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) { + return new DateConf(); + } + DateConf dateInfo = new DateConf(); + dateInfo.setDateMode(DateMode.BETWEEN); + FilterExpression firstExpression = dateExpressions.get(0); + + FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator()); + if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) { + dateInfo.setStartDate(firstExpression.getFieldValue().toString()); + dateInfo.setEndDate(firstExpression.getFieldValue().toString()); + dateInfo.setDateMode(DateMode.BETWEEN); + return dateInfo; + } + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN, + FilterOperatorEnum.GREATER_THAN_EQUALS)) { + dateInfo.setStartDate(firstExpression.getFieldValue().toString()); + if (hasSecondDate(dateExpressions)) { + dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString()); + } + } + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN, + FilterOperatorEnum.MINOR_THAN_EQUALS)) { + dateInfo.setEndDate(firstExpression.getFieldValue().toString()); + if (hasSecondDate(dateExpressions)) { + dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString()); + } + } + return dateInfo; + } + + private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, + FilterOperatorEnum... operatorEnums) { + return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); + } + + private boolean hasSecondDate(List dateExpressions) { + return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); + } + + protected Map getNameToElement(Long modelId) { + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + List dimensions = semanticSchema.getDimensions(); + List metrics = semanticSchema.getMetrics(); + + List allElements = Lists.newArrayList(); + allElements.addAll(dimensions); + allElements.addAll(metrics); + //support alias + return allElements.stream() + .filter(schemaElement -> schemaElement.getModel().equals(modelId)) + .flatMap(schemaElement -> { + Set> result = new HashSet<>(); + result.add(Pair.of(schemaElement.getName(), schemaElement)); + List aliasList = schemaElement.getAlias(); + if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) { + for (String alias : aliasList) { + result.add(Pair.of(alias, schemaElement)); + } + } + return result.stream(); + }) + .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); + } + + protected void convertBizNameToName(QueryStructReq queryStructReq) { + SchemaService schemaService = ContextUtils.getBean(SchemaService.class); + Map bizNameToName = schemaService.getSemanticSchema() + .getBizNameToName(queryStructReq.getModelId()); + List orders = queryStructReq.getOrders(); + if (CollectionUtils.isNotEmpty(orders)) { + for (Order order : orders) { + order.setColumn(bizNameToName.get(order.getColumn())); + } + } + List aggregators = queryStructReq.getAggregators(); + if (CollectionUtils.isNotEmpty(aggregators)) { + for (Aggregator aggregator : aggregators) { + aggregator.setColumn(bizNameToName.get(aggregator.getColumn())); + } + } + List groups = queryStructReq.getGroups(); + if (CollectionUtils.isNotEmpty(groups)) { + groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList()); + queryStructReq.setGroups(groups); + } + List dimensionFilters = queryStructReq.getDimensionFilters(); + if (CollectionUtils.isNotEmpty(dimensionFilters)) { + dimensionFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); + } + List metricFilters = queryStructReq.getMetricFilters(); + if (CollectionUtils.isNotEmpty(dimensionFilters)) { + metricFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); + } + queryStructReq.setModelName(parseInfo.getModelName()); + } + + protected void initS2SqlByStruct() { + QueryStructReq queryStructReq = convertQueryStruct(); + convertBizNameToName(queryStructReq); + QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq); + parseInfo.getSqlInfo().setS2QL(queryS2QLReq.getSql()); + parseInfo.getSqlInfo().setLogicSql(queryS2QLReq.getSql()); + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java index 65b85768a..579664162 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java @@ -10,7 +10,6 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; import com.tencent.supersonic.chat.config.OptimizationConfig; -import com.tencent.supersonic.chat.corrector.CorrectorService; import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; @@ -30,7 +29,6 @@ import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -42,8 +40,6 @@ public class MetricInterpretQuery extends PluginSemanticQuery { public static final String QUERY_MODE = "METRIC_INTERPRET"; - @Autowired - private CorrectorService correctorService; public MetricInterpretQuery() { QueryManager.register(this); @@ -56,15 +52,13 @@ public class MetricInterpretQuery extends PluginSemanticQuery { @Override public QueryResult execute(User user) throws SqlParseException { - QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo); - fillAggregator(queryStructReq, parseInfo.getMetrics()); - queryStructReq.setNativeQuery(true); + QueryStructReq queryStructReq = convertQueryStruct(); SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); if (optimizationConfig.isUseS2qlSwitch()) { - correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo); + queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL()); + queryStructReq.setS2QL(parseInfo.getSqlInfo().getQuerySql()); } QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user); @@ -87,6 +81,18 @@ public class MetricInterpretQuery extends PluginSemanticQuery { return queryResult; } + @Override + public void initS2Sql(User user) { + initS2SqlByStruct(); + } + + protected QueryStructReq convertQueryStruct() { + QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo); + fillAggregator(queryStructReq, parseInfo.getMetrics()); + queryStructReq.setNativeQuery(true); + return queryStructReq; + } + private String replaceText(String text, List schemaElementMatches, Map replacedMap) { if (CollectionUtils.isEmpty(schemaElementMatches)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java index b21fbcba1..1225c366c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java @@ -10,10 +10,7 @@ import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.common.pojo.QueryColumn; -import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; -import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; import java.util.ArrayList; import java.util.List; @@ -65,22 +62,9 @@ public class S2QLQuery extends PluginSemanticQuery { return queryResult; } - @Override - public SqlInfo explain(User user) { + public void initS2Sql(User user) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); - ExplainSqlReq explainSqlReq = null; - try { - QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId()); - explainSqlReq = ExplainSqlReq.builder() - .queryTypeEnum(QueryTypeEnum.SQL) - .queryReq(queryS2QLReq) - .build(); - ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user); - sqlInfo.setQuerySql(explain.getSql()); - } catch (Exception e) { - log.error("explain error explainSqlReq:{}", explainSqlReq, e); - } - return sqlInfo; + sqlInfo.setLogicSql(sqlInfo.getS2QL()); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java index 9da892bef..5dda72c1f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java @@ -1,26 +1,19 @@ package com.tencent.supersonic.chat.query.plugin; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.chat.api.component.SemanticQuery; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; +import com.tencent.supersonic.chat.query.BaseSemanticQuery; import lombok.extern.slf4j.Slf4j; @Slf4j -public abstract class PluginSemanticQuery implements SemanticQuery { +public abstract class PluginSemanticQuery extends BaseSemanticQuery { - protected SemanticParseInfo parseInfo = new SemanticParseInfo(); - - public void setParseInfo(SemanticParseInfo parseInfo) { - this.parseInfo = parseInfo; - } - - public SemanticParseInfo getParseInfo() { - return parseInfo; + @Override + public String explain(User user) { + return null; } @Override - public SqlInfo explain(User user) { - return null; + public void initS2Sql(User user) { + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java index ed1a469a1..afe0df114 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java @@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.query.rule; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; -import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ModelSchema; import com.tencent.supersonic.chat.api.pojo.QueryContext; @@ -14,9 +13,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; -import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.config.OptimizationConfig; -import com.tencent.supersonic.chat.corrector.CorrectorService; +import com.tencent.supersonic.chat.query.BaseSemanticQuery; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.utils.ComponentFactory; @@ -24,14 +22,9 @@ import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; -import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; -import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; -import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -44,9 +37,8 @@ import org.apache.commons.lang3.StringUtils; @Slf4j @ToString -public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { +public abstract class RuleSemanticQuery extends BaseSemanticQuery { - protected SemanticParseInfo parseInfo = new SemanticParseInfo(); protected QueryMatcher queryMatcher = new QueryMatcher(); protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); @@ -59,6 +51,11 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { return queryMatcher.match(candidateElementMatches); } + @Override + public void initS2Sql(User user) { + initS2SqlByStruct(); + } + public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) { parseInfo.setQueryMode(getQueryMode()); @@ -203,10 +200,9 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { QueryStructReq queryStructReq = convertQueryStruct(); OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); if (optimizationConfig.isUseS2qlSwitch()) { - CorrectorService correctorService = ContextUtils.getBean(CorrectorService.class); - correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo); + queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL()); + queryStructReq.setLogicSql(parseInfo.getSqlInfo().getLogicSql()); } QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user); @@ -227,29 +223,6 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { return queryResult; } - - @Override - public SqlInfo explain(User user) { - SqlInfo sqlInfo = parseInfo.getSqlInfo(); - ExplainSqlReq explainSqlReq = null; - try { - QueryStructReq queryStructReq = convertQueryStruct(); - CorrectorService correctorService = ContextUtils.getBean(CorrectorService.class); - correctorService.addS2QLAndLoginSql(queryStructReq, parseInfo); - - QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId()); - explainSqlReq = ExplainSqlReq.builder() - .queryTypeEnum(QueryTypeEnum.SQL) - .queryReq(queryS2QLReq) - .build(); - ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user); - sqlInfo.setQuerySql(explain.getSql()); - } catch (Exception e) { - log.error("explain error explainSqlReq:{}", explainSqlReq, e); - } - return sqlInfo; - } - protected boolean isMultiStructQuery() { return false; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java index a4dc5a265..808def88c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java @@ -6,7 +6,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseTimeCostDO; -import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.common.util.JsonUtil; @@ -15,6 +14,7 @@ import java.util.Map; import java.util.Objects; import java.util.function.Function; import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; public class SqlInfoParseResponder implements ParseResponder { @@ -64,11 +64,11 @@ public class SqlInfoParseResponder implements ParseResponder { return; } semanticQuery.setParseInfo(parseInfo); - SqlInfo sqlInfo = semanticQuery.explain(queryReq.getUser()); - if (Objects.isNull(sqlInfo)) { + String explainSql = semanticQuery.explain(queryReq.getUser()); + if (StringUtils.isBlank(explainSql)) { return; } - parseInfo.setSqlInfo(sqlInfo); + parseInfo.getSqlInfo().setQuerySql(explainSql); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java index e31edda02..3166702f2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java @@ -30,7 +30,6 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.MetricInfo; import com.tencent.supersonic.chat.api.pojo.response.ModelInfo; import com.tencent.supersonic.chat.config.AggregatorConfig; -import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.common.pojo.DateConf; @@ -427,9 +426,6 @@ public class SemanticService { queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField))); queryStructReq.setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, results)); - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); - QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user); if (Objects.nonNull(queryResp) && !CollectionUtils.isEmpty(queryResp.getResultList())) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 1fadbba58..c22ae344a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.service.impl; import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SchemaMapper; +import com.tencent.supersonic.chat.api.component.SemanticCorrector; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticQuery; @@ -20,8 +21,6 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; -import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; -import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; @@ -112,14 +111,18 @@ public class QueryServiceImpl implements QueryService { private QuerySelector querySelector = ComponentFactory.getQuerySelector(); private List parseResponders = ComponentFactory.getParseResponders(); private List executeResponders = ComponentFactory.getExecuteResponders(); + private List semanticCorrectors = ComponentFactory.getSqlCorrections(); @Override public ParseResp performParsing(QueryReq queryReq) { Long parseTime = System.currentTimeMillis(); + //1. build queryContext and chatContext QueryContext queryCtx = new QueryContext(queryReq); // in order to support multi-turn conversation, chat context is needed ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId()); List timeCostDOList = new ArrayList<>(); + + //2. mapper schemaMappers.stream().forEach(mapper -> { Long startTime = System.currentTimeMillis(); mapper.map(queryCtx); @@ -127,6 +130,8 @@ public class QueryServiceImpl implements QueryService { .interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build()); log.info("{} result:{}", mapper.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); }); + + //3. parser semanticParsers.stream().forEach(parser -> { Long startTime = System.currentTimeMillis(); parser.parse(queryCtx, chatCtx); @@ -134,17 +139,31 @@ public class QueryServiceImpl implements QueryService { .interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build()); log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); }); + + //4. corrector + List candidateQueries = queryCtx.getCandidateQueries(); + if (CollectionUtils.isNotEmpty(candidateQueries)) { + for (SemanticQuery semanticQuery : candidateQueries) { + semanticQuery.initS2Sql(queryReq.getUser()); + semanticCorrectors.stream().forEach(correction -> { + correction.correct(queryReq, semanticQuery.getParseInfo()); + }); + semanticQuery.updateParseInfo(); + } + } + + //5. generate parsing results. ParseResp parseResult; List chatParseDOS = Lists.newArrayList(); - if (queryCtx.getCandidateQueries().size() > 0) { - log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect( + if (candidateQueries.size() > 0) { + log.debug("pick before [{}]", candidateQueries.stream().collect( Collectors.toList())); - List selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq); + List selectedQueries = querySelector.select(candidateQueries, queryReq); log.debug("pick after [{}]", selectedQueries.stream().collect( Collectors.toList())); List selectedParses = convertParseInfo(selectedQueries); - List candidateParses = convertParseInfo(queryCtx.getCandidateQueries()); + List candidateParses = convertParseInfo(candidateQueries); candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses); parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) @@ -161,6 +180,7 @@ public class QueryServiceImpl implements QueryService { .state(ParseResp.ParseState.FAILED) .build(); } + //6. responders for (ParseResponder parseResponder : parseResponders) { Long startTime = System.currentTimeMillis(); parseResponder.fillResponse(parseResult, queryCtx, chatParseDOS); @@ -315,14 +335,14 @@ public class QueryServiceImpl implements QueryService { Set removeWhereFieldNames = new HashSet<>(); Set removeHavingFieldNames = new HashSet<>(); // replace where filter - updateFilters(filedNameToValueMap, whereExpressionList, queryData.getDimensionFilters(), + updateFilters(whereExpressionList, queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); updateDateInfo(queryData, parseInfo, filedNameToValueMap, whereExpressionList, addWhereConditions, removeWhereFieldNames); correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); // replace having filter - updateFilters(havingFiledNameToValueMap, havingExpressionList, queryData.getDimensionFilters(), + updateFilters(havingExpressionList, queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); @@ -333,9 +353,9 @@ public class QueryServiceImpl implements QueryService { log.info("correctorSql after replacing:{}", correctorSql); parseInfo.getSqlInfo().setLogicSql(correctorSql); semanticQuery.setParseInfo(parseInfo); - SqlInfo sqlInfo = semanticQuery.explain(user); - if (!Objects.isNull(sqlInfo)) { - parseInfo.setSqlInfo(sqlInfo); + String explainSql = semanticQuery.explain(user); + if (StringUtils.isNotBlank(explainSql)) { + parseInfo.getSqlInfo().setQuerySql(explainSql); } } semanticQuery.setParseInfo(parseInfo); @@ -432,8 +452,7 @@ public class QueryServiceImpl implements QueryService { addConditions.add(comparisonExpression); } - private void updateFilters(Map> filedNameToValueMap, - List filterExpressionList, + private void updateFilters(List filterExpressionList, Set metricFilters, Set contextMetricFilters, List addConditions, @@ -624,11 +643,7 @@ public class QueryServiceImpl implements QueryService { groups.add(dimensionValueReq.getBizName()); queryStructReq.setGroups(groups); SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); - - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); - QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user); - return queryResultWithSchemaResp; + return semanticInterpreter.queryByStruct(queryStructReq, user); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java index 4149f6b14..f37682dbf 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java @@ -10,7 +10,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.config.DefaultMetric; import com.tencent.supersonic.chat.config.Dim4Dict; -import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; @@ -19,7 +18,6 @@ import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.util.ArrayList; @@ -52,13 +50,10 @@ public class DictQueryHelper { public List fetchDimValueSingle(Long modelId, DefaultMetric defaultMetricDesc, Dim4Dict dim4Dict, - User user) { + User user) { List data = new ArrayList<>(); QueryStructReq queryStructCmd = generateQueryStructCmd(modelId, defaultMetricDesc, dim4Dict); try { - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - queryStructCmd.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); - QueryResultWithSchemaResp queryResultWithColumns = semanticInterpreter.queryByStruct(queryStructCmd, user); log.info("fetchDimValueSingle sql:{}", queryResultWithColumns.getSql()); @@ -100,7 +95,7 @@ public class DictQueryHelper { } private List generateFileData(List> resultList, String nature, String dimName, - String metricName, Dim4Dict dim4Dict) { + String metricName, Dim4Dict dim4Dict) { List data = new ArrayList<>(); if (CollectionUtils.isEmpty(resultList)) { return data; @@ -125,7 +120,7 @@ public class DictQueryHelper { } private void constructDataLines(Map valueAndFrequencyPair, String nature, - List data, Dim4Dict dim4Dict) { + List data, Dim4Dict dim4Dict) { valueAndFrequencyPair.forEach((dimValue, metric) -> { if (metric > MAX_FREQUENCY) { metric = MAX_FREQUENCY; diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java index dd7ce4395..997f3e05c 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java @@ -59,8 +59,6 @@ public class QueryStructReq { private Boolean nativeQuery = false; private Cache cacheInfo; - private boolean useS2qlSwitch; - /** * Later deleted for compatibility only */