From 414aaaa0b6262e5ffc933d8c49db9797c829f80e Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Tue, 29 Oct 2024 00:27:24 +0800 Subject: [PATCH] [improvement][chat]Modify core workflow of `NL2SQLParser`, always invoking rule-based parsers first.#1729 --- .../chat/server/parser/NL2SQLParser.java | 36 ++++++++++--------- .../chat/server/pojo/ParseContext.java | 23 ++++++++---- .../parse/ParseInfoSortProcessor.java | 2 +- .../service/impl/ChatQueryServiceImpl.java | 7 ---- .../chat/parser/rule/RuleSqlParser.java | 6 ++++ 5 files changed, 42 insertions(+), 32 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index bae33b3b4..557c4b465 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -11,6 +11,7 @@ import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.enums.AppModule; +import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ContextUtils; @@ -73,11 +74,7 @@ public class NL2SQLParser implements ChatQueryParser { @Override public void parse(ParseContext parseContext) { - if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) { - return; - } - if (parseContext.needFeedback()) { - processFeedback(parseContext); + if (!parseContext.enableNL2SQL()) { return; } @@ -88,20 +85,25 @@ public class NL2SQLParser implements ChatQueryParser { if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) { queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); } - if (parseContext.enableLLM()) { - rewriteMultiTurn(parseContext, queryNLReq); - addDynamicExemplars(parseContext, queryNLReq); + + if (parseContext.needRuleParse()) { + queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); + ChatParseResp parseResp = parseContext.getResponse(); + for (MapModeEnum mode : MapModeEnum.values()) { + queryNLReq.setMapModeEnum(mode); + doParse(queryNLReq, parseResp); + } } - doParse(queryNLReq, parseContext.getResponse()); - } - - private void processFeedback(ParseContext parseContext) { - QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); - ChatParseResp parseResp = parseContext.getResponse(); - for (MapModeEnum mode : MapModeEnum.values()) { - queryNLReq.setMapModeEnum(mode); - doParse(queryNLReq, parseResp); + if (parseContext.needLLMParse() && !parseContext.needFeedback()) { + SemanticParseInfo selectedParse = parseContext.getRequest().getSelectedParse(); + queryNLReq.setSelectedParseInfo(Objects.nonNull(selectedParse) ? selectedParse + : parseContext.getResponse().getSelectedParses().get(0)); + queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); + parseContext.getResponse().getSelectedParses().clear(); + rewriteMultiTurn(parseContext, queryNLReq); + addDynamicExemplars(parseContext, queryNLReq); + doParse(queryNLReq, parseContext.getResponse()); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index 23c19e95b..af89bfe4a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.pojo; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp; import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import lombok.Data; import java.util.Objects; @@ -20,14 +19,24 @@ public class ParseContext { } public boolean enableNL2SQL() { - return agent.containsDatasetTool(); - } - - public boolean needFeedback() { - return agent.enableFeedback() && Objects.isNull(request.getSelectedParse()); + return Objects.nonNull(agent) && agent.containsDatasetTool(); } public boolean enableLLM() { - return !(needFeedback() || request.isDisableLLM()); + return !request.isDisableLLM(); + } + + public boolean needFeedback() { + return agent.enableFeedback() && (Objects.isNull(request.getSelectedParse()) + && response.getSelectedParses().size() > 1); + } + + public boolean needRuleParse() { + return Objects.isNull(request.getSelectedParse()); + } + + public boolean needLLMParse() { + return enableLLM() && (Objects.nonNull(request.getSelectedParse()) + || !response.getSelectedParses().isEmpty()); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java index c0589b205..2004c57b2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java @@ -24,7 +24,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor { } }); - Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0); + sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 1 : 0); parseContext.getResponse().setSelectedParses(sortedParseInfo); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 61afeee5c..78571dc68 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -119,13 +119,6 @@ public class ChatQueryServiceImpl implements ChatQueryService { chatManageService.updateParseCostTime(parseContext.getResponse()); } - // no need for explicit user feedback if there is only one candidate parses - if (parseContext.needFeedback() && parseContext.getResponse().getSelectedParses().size() == 1) { - chatParseReq.setQueryId(parseContext.getResponse().getQueryId()); - chatParseReq.setSelectedParse(parseContext.getResponse().getSelectedParses().get(0)); - return parse(chatParseReq); - } - return parseContext.getResponse(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index 79e0408fc..19a7ce26f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -1,9 +1,11 @@ package com.tencent.supersonic.headless.chat.parser.rule; +import com.google.common.collect.Lists; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.SemanticParser; +import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import lombok.extern.slf4j.Slf4j; @@ -27,6 +29,7 @@ public class RuleSqlParser implements SemanticParser { return; } SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); + List candidateQueries = Lists.newArrayList(); // iterate all schemaElementMatches to resolve query mode for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) { List elementMatches = mapInfo.getMatchedElements(dataSetId); @@ -36,7 +39,10 @@ public class RuleSqlParser implements SemanticParser { query.fillParseInfo(chatQueryContext); chatQueryContext.getCandidateQueries().add(query); } + candidateQueries.addAll(chatQueryContext.getCandidateQueries()); + chatQueryContext.getCandidateQueries().clear(); } + chatQueryContext.setCandidateQueries(candidateQueries); auxiliaryParsers.forEach(p -> p.parse(chatQueryContext)); }