[improvement][chat] Modify core workflow of NL2SQLParser, always invoking rule-based parsers first. #1729

This commit is contained in:
jerryjzhang
2024-10-29 00:27:24 +08:00
parent 5d9b1b917e
commit 414aaaa0b6
5 changed files with 42 additions and 32 deletions

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -73,11 +74,7 @@ public class NL2SQLParser implements ChatQueryParser {
@Override
public void parse(ParseContext parseContext) {
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
return;
}
if (parseContext.needFeedback()) {
processFeedback(parseContext);
if (!parseContext.enableNL2SQL()) {
return;
}
@@ -88,20 +85,25 @@ public class NL2SQLParser implements ChatQueryParser {
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
if (parseContext.enableLLM()) {
rewriteMultiTurn(parseContext, queryNLReq);
addDynamicExemplars(parseContext, queryNLReq);
if (parseContext.needRuleParse()) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
ChatParseResp parseResp = parseContext.getResponse();
for (MapModeEnum mode : MapModeEnum.values()) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
}
}
doParse(queryNLReq, parseContext.getResponse());
}
private void processFeedback(ParseContext parseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
ChatParseResp parseResp = parseContext.getResponse();
for (MapModeEnum mode : MapModeEnum.values()) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
if (parseContext.needLLMParse() && !parseContext.needFeedback()) {
SemanticParseInfo selectedParse = parseContext.getRequest().getSelectedParse();
queryNLReq.setSelectedParseInfo(Objects.nonNull(selectedParse) ? selectedParse
: parseContext.getResponse().getSelectedParses().get(0));
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
parseContext.getResponse().getSelectedParses().clear();
rewriteMultiTurn(parseContext, queryNLReq);
addDynamicExemplars(parseContext, queryNLReq);
doParse(queryNLReq, parseContext.getResponse());
}
}

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
import java.util.Objects;
@@ -20,14 +19,24 @@ public class ParseContext {
}
public boolean enableNL2SQL() {
return agent.containsDatasetTool();
}
public boolean needFeedback() {
return agent.enableFeedback() && Objects.isNull(request.getSelectedParse());
return Objects.nonNull(agent) && agent.containsDatasetTool();
}
public boolean enableLLM() {
return !(needFeedback() || request.isDisableLLM());
return !request.isDisableLLM();
}
public boolean needFeedback() {
return agent.enableFeedback() && (Objects.isNull(request.getSelectedParse())
&& response.getSelectedParses().size() > 1);
}
public boolean needRuleParse() {
return Objects.isNull(request.getSelectedParse());
}
public boolean needLLMParse() {
return enableLLM() && (Objects.nonNull(request.getSelectedParse())
|| !response.getSelectedParses().isEmpty());
}
}

View File

@@ -24,7 +24,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
}
});
Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0);
sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 1 : 0);
parseContext.getResponse().setSelectedParses(sortedParseInfo);
}
}

View File

@@ -119,13 +119,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
chatManageService.updateParseCostTime(parseContext.getResponse());
}
// no need for explicit user feedback if there is only one candidate parse
if (parseContext.needFeedback() && parseContext.getResponse().getSelectedParses().size() == 1) {
chatParseReq.setQueryId(parseContext.getResponse().getQueryId());
chatParseReq.setSelectedParse(parseContext.getResponse().getSelectedParses().get(0));
return parse(chatParseReq);
}
return parseContext.getResponse();
}

View File

@@ -1,9 +1,11 @@
package com.tencent.supersonic.headless.chat.parser.rule;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@@ -27,6 +29,7 @@ public class RuleSqlParser implements SemanticParser {
return;
}
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
List<SemanticQuery> candidateQueries = Lists.newArrayList();
// iterate all schemaElementMatches to resolve query mode
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
@@ -36,7 +39,10 @@ public class RuleSqlParser implements SemanticParser {
query.fillParseInfo(chatQueryContext);
chatQueryContext.getCandidateQueries().add(query);
}
candidateQueries.addAll(chatQueryContext.getCandidateQueries());
chatQueryContext.getCandidateQueries().clear();
}
chatQueryContext.setCandidateQueries(candidateQueries);
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
}