From 400b9f86f00a98774adc35c8810bf4babb864ac7 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Tue, 29 Oct 2024 13:07:02 +0800 Subject: [PATCH] [improvement][chat]Modify core workflow of NL2SQLParser, always invoking rule-based parsers first.#1729 --- .../chat/api/pojo/response/ChatParseResp.java | 1 + .../chat/server/parser/NL2SQLParser.java | 58 ++++++++++-------- .../chat/server/pojo/ParseContext.java | 4 -- .../parse/ParseInfoSortProcessor.java | 59 +++++++++++++++---- .../chat/server/util/QueryReqConverter.java | 2 +- .../common/pojo/enums/Text2SQLType.java | 6 +- .../headless/api/pojo/request/QueryNLReq.java | 2 +- .../headless/chat/utils/QueryReqBuilder.java | 15 ++--- .../tencent/supersonic/util/DataUtils.java | 1 + 9 files changed, 96 insertions(+), 52 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ChatParseResp.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ChatParseResp.java index 04e37bbfd..42bc53d4d 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ChatParseResp.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ChatParseResp.java @@ -22,4 +22,5 @@ public class ChatParseResp { public ChatParseResp(Long queryId) { this.queryId = queryId; } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 557c4b465..4762304fe 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -36,12 +36,7 @@ import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; +import 
java.util.*; import java.util.stream.Collectors; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; @@ -78,29 +73,46 @@ public class NL2SQLParser implements ChatQueryParser { return; } - QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); - ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); - ChatContext chatCtx = - chatContextService.getOrCreateContext(parseContext.getRequest().getChatId()); - if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) { - queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); - } - - if (parseContext.needRuleParse()) { + // first go with rule-based parsers unless the user has already selected one parse. + if (Objects.isNull(parseContext.getRequest().getSelectedParse())) { + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); - ChatParseResp parseResp = parseContext.getResponse(); - for (MapModeEnum mode : MapModeEnum.values()) { - queryNLReq.setMapModeEnum(mode); - doParse(queryNLReq, parseResp); + + // inject the semantic parse saved in the chat context + ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); + ChatContext chatCtx = + chatContextService.getOrCreateContext(parseContext.getRequest().getChatId()); + if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) { + queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); + } + + // for every requested dataSet, iteratively invoke the rule-based parser + // with different mapModes, until a valid semantic parse is derived. 
+ Set requestedDatasets = queryNLReq.getDataSetIds(); + for (Long datasetId : requestedDatasets) { + queryNLReq.setDataSetIds(Collections.singleton(datasetId)); + ChatParseResp parseResp = parseContext.getResponse(); + for (MapModeEnum mode : MapModeEnum.values()) { + queryNLReq.setMapModeEnum(mode); + doParse(queryNLReq, parseResp); + if (!parseResp.getSelectedParses().isEmpty()) { + break; + } + } } } + // next go with llm-based parsers unless LLM is disabled or user feedback is needed. if (parseContext.needLLMParse() && !parseContext.needFeedback()) { - SemanticParseInfo selectedParse = parseContext.getRequest().getSelectedParse(); - queryNLReq.setSelectedParseInfo(Objects.nonNull(selectedParse) ? selectedParse + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); + queryNLReq.setText2SQLType(Text2SQLType.LLM_OR_RULE); + + // either the user or the system selects one parse from the candidate parses. + SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse(); + queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? 
userSelectParse : parseContext.getResponse().getSelectedParses().get(0)); - queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); - parseContext.getResponse().getSelectedParses().clear(); + + parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId())); rewriteMultiTurn(parseContext, queryNLReq); addDynamicExemplars(parseContext, queryNLReq); doParse(queryNLReq, parseContext.getResponse()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index af89bfe4a..5e97c83d2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -31,10 +31,6 @@ public class ParseContext { && response.getSelectedParses().size() > 1); } - public boolean needRuleParse() { - return Objects.isNull(request.getSelectedParse()); - } - public boolean needLLMParse() { return enableLLM() && (Objects.nonNull(request.getSelectedParse()) || !response.getSelectedParses().isEmpty()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java index 2004c57b2..e8d149ae3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java @@ -1,30 +1,67 @@ package com.tencent.supersonic.chat.server.processor.parse; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import com.tencent.supersonic.chat.server.pojo.ParseContext; +import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import 
com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult; +import lombok.extern.slf4j.Slf4j; import java.util.*; /** * ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \ **/ +@Slf4j public class ParseInfoSortProcessor implements ParseResultProcessor { @Override public void process(ParseContext parseContext) { - Set parseInfoText = Sets.newHashSet(); - List sortedParseInfo = Lists.newArrayList(); + List selectedParses = parseContext.getResponse().getSelectedParses(); - parseContext.getResponse().getSelectedParses().forEach(p -> { - if (!parseInfoText.contains(p.getTextInfo())) { - sortedParseInfo.add(p); - parseInfoText.add(p.getTextInfo()); + selectedParses.sort((o1, o2) -> { + DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches()); + DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches()); + + double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity(); + if (difference == 0) { + difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity(); + if (difference == 0) { + difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity(); + } + if (difference == 0) { + difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt(); + } } + return difference >= 0 ? -1 : 1; }); - - sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 
1 : 0); - parseContext.getResponse().setSelectedParses(sortedParseInfo); + // re-assign parseId + for (int i = 0; i < selectedParses.size(); i++) { + SemanticParseInfo parseInfo = selectedParses.get(i); + parseInfo.setId(i + 1); + } } + + private DataSetMatchResult getDataSetMatchResult(List elementMatches) { + double maxMetricSimilarity = 0; + double maxDatasetSimilarity = 0; + double totalSimilarity = 0; + long maxMetricUseCnt = 0L; + for (SchemaElementMatch match : elementMatches) { + if (SchemaElementType.DATASET.equals(match.getElement().getType())) { + maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity()); + } + if (SchemaElementType.METRIC.equals(match.getElement().getType())) { + maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity()); + if (Objects.nonNull(match.getElement().getUseCnt())) { + maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt()); + } + } + totalSimilarity += match.getSimilarity(); + } + return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity) + .maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity) + .build(); + } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 569f84f71..4954f6ee2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -17,7 +17,7 @@ public class QueryReqConverter { QueryNLReq queryNLReq = new QueryNLReq(); BeanMapper.mapper(parseContext.getRequest(), queryNLReq); queryNLReq.setText2SQLType( - parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE); + parseContext.enableLLM() ? 
Text2SQLType.LLM_OR_RULE : Text2SQLType.ONLY_RULE); queryNLReq.setDataSetIds(getDataSetIds(parseContext)); queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse()); diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java index cd965f292..f38ed6efb 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java @@ -1,13 +1,13 @@ package com.tencent.supersonic.common.pojo.enums; public enum Text2SQLType { - ONLY_RULE, ONLY_LLM, RULE_AND_LLM; + ONLY_RULE, ONLY_LLM, LLM_OR_RULE; public boolean enableRule() { - return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM); + return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE); } public boolean enableLLM() { - return this.equals(ONLY_LLM) || this.equals(RULE_AND_LLM); + return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE); } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index a0bbe1e7f..83116f642 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -22,7 +22,7 @@ public class QueryNLReq extends SemanticQueryReq { private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; - private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; + private Text2SQLType text2SQLType = Text2SQLType.LLM_OR_RULE; private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private QueryDataType queryDataType = QueryDataType.ALL; private Map chatAppConfig; diff --git 
a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java index 0289245a4..cb1dd57b3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java @@ -60,15 +60,12 @@ public class QueryReqBuilder { queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName) .collect(Collectors.toList())); queryStructReq.setLimit(parseInfo.getLimit()); - // only one metric is queried at once - Set metrics = parseInfo.getMetrics(); - if (!CollectionUtils.isEmpty(metrics)) { - SchemaElement metricElement = parseInfo.getMetrics().iterator().next(); - Set order = - getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement); - queryStructReq - .setAggregators(getAggregatorByMetric(parseInfo.getAggType(), metricElement)); - queryStructReq.setOrders(new ArrayList<>(order)); + + for (SchemaElement metricElement : parseInfo.getMetrics()) { + queryStructReq.getAggregators() + .addAll(getAggregatorByMetric(parseInfo.getAggType(), metricElement)); + queryStructReq.setOrders(new ArrayList<>( + getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement))); } deletionDuplicated(queryStructReq); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java index 589265793..4e28f61a3 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java @@ -40,6 +40,7 @@ public class DataUtils { public static ChatParseReq getChatParseReq(Integer id, String query, boolean enableLLM) { ChatParseReq chatParseReq = new ChatParseReq(); chatParseReq.setQueryText(query); + 
chatParseReq.setAgentId(metricAgentId); chatParseReq.setChatId(id); chatParseReq.setUser(user_test); chatParseReq.setDisableLLM(!enableLLM);