From 4102082f2ae31be78cde12146019b2f9f546b94c Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Wed, 30 Oct 2024 10:16:08 +0800 Subject: [PATCH] [improvement][chat]Make the total # of candidates to the user configurable. --- .../chat/server/parser/NL2SQLParser.java | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 8f84de6a2..631fd9595 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -45,6 +45,7 @@ import java.util.Set; import java.util.stream.Collectors; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; +import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SHOW_COUNT; @Slf4j public class NL2SQLParser implements ChatQueryParser { @@ -86,6 +87,7 @@ public class NL2SQLParser implements ChatQueryParser { // for every requested dataSet, recursively invoke rule-based parser with different // mapModes Set requestedDatasets = queryNLReq.getDataSetIds(); + List candidateParses = Lists.newArrayList(); for (Long datasetId : requestedDatasets) { queryNLReq.setDataSetIds(Collections.singleton(datasetId)); ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId()); @@ -98,21 +100,28 @@ public class NL2SQLParser implements ChatQueryParser { queryNLReq.setMapModeEnum(MapModeEnum.LOOSE); doParse(queryNLReq, parseResp); } - // for one dataset select the most suitable parses - List sortedParses = parseResp.getSelectedParses().stream() - .sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1) - .collect(Collectors.toList()); - parseContext.getResponse().getSelectedParses().addAll(sortedParses); + // for one dataset select the top 1 parse after sorting + SemanticParseInfo.sort(parseResp.getSelectedParses()); + candidateParses.add(parseResp.getSelectedParses().get(0)); } - SemanticParseInfo.sort(parseContext.getResponse().getSelectedParses()); + ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); + int parserShowCount = + Integer.parseInt(parserConfig.getParameterValue(PARSER_SHOW_COUNT)); + SemanticParseInfo.sort(candidateParses); + parseContext.getResponse().setSelectedParses( + candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size()))); } // next go with llm-based parsers unless LLM is disabled or use feedback is needed. if (parseContext.needLLMParse() && !parseContext.needFeedback()) { + // either the user or the system selects one parse from the candidate parses. + if (Objects.isNull(parseContext.getRequest().getSelectedParse()) + && parseContext.getResponse().getSelectedParses().isEmpty()) { + return; + } + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); queryNLReq.setText2SQLType(Text2SQLType.LLM_OR_RULE); - - // either the user or the system selects one parse from the candidate parses. SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse(); queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse : parseContext.getResponse().getSelectedParses().get(0));