[improvement][chat]Make the total # of candidates to the user configurable.

This commit is contained in:
jerryjzhang
2024-10-30 10:16:08 +08:00
parent 29489b4669
commit 4102082f2a

View File

@@ -45,6 +45,7 @@ import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SHOW_COUNT;
@Slf4j @Slf4j
public class NL2SQLParser implements ChatQueryParser { public class NL2SQLParser implements ChatQueryParser {
@@ -86,6 +87,7 @@ public class NL2SQLParser implements ChatQueryParser {
// for every requested dataSet, recursively invoke rule-based parser with different // for every requested dataSet, recursively invoke rule-based parser with different
// mapModes // mapModes
Set<Long> requestedDatasets = queryNLReq.getDataSetIds(); Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
List<SemanticParseInfo> candidateParses = Lists.newArrayList();
for (Long datasetId : requestedDatasets) { for (Long datasetId : requestedDatasets) {
queryNLReq.setDataSetIds(Collections.singleton(datasetId)); queryNLReq.setDataSetIds(Collections.singleton(datasetId));
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId()); ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
@@ -98,21 +100,28 @@ public class NL2SQLParser implements ChatQueryParser {
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE); queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
doParse(queryNLReq, parseResp); doParse(queryNLReq, parseResp);
} }
// for one dataset select the most suitable parses // for one dataset select the top 1 parse after sorting
List<SemanticParseInfo> sortedParses = parseResp.getSelectedParses().stream() SemanticParseInfo.sort(parseResp.getSelectedParses());
.sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1) candidateParses.add(parseResp.getSelectedParses().get(0));
.collect(Collectors.toList());
parseContext.getResponse().getSelectedParses().addAll(sortedParses);
} }
SemanticParseInfo.sort(parseContext.getResponse().getSelectedParses()); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int parserShowCount =
Integer.parseInt(parserConfig.getParameterValue(PARSER_SHOW_COUNT));
SemanticParseInfo.sort(candidateParses);
parseContext.getResponse().setSelectedParses(
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
} }
// next go with llm-based parsers unless LLM is disabled or use feedback is needed. // next go with llm-based parsers unless LLM is disabled or use feedback is needed.
if (parseContext.needLLMParse() && !parseContext.needFeedback()) { if (parseContext.needLLMParse() && !parseContext.needFeedback()) {
// either the user or the system selects one parse from the candidate parses.
if (Objects.isNull(parseContext.getRequest().getSelectedParse())
&& parseContext.getResponse().getSelectedParses().isEmpty()) {
return;
}
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
queryNLReq.setText2SQLType(Text2SQLType.LLM_OR_RULE); queryNLReq.setText2SQLType(Text2SQLType.LLM_OR_RULE);
// either the user or the system selects one parse from the candidate parses.
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse(); SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
: parseContext.getResponse().getSelectedParses().get(0)); : parseContext.getResponse().getSelectedParses().get(0));