mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][chat]Modify core workflow of NL2SQLParser, always invoking rule-based parsers first.#1729
This commit is contained in:
@@ -11,6 +11,7 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -73,11 +74,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
@Override
|
||||
public void parse(ParseContext parseContext) {
|
||||
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
||||
return;
|
||||
}
|
||||
if (parseContext.needFeedback()) {
|
||||
processFeedback(parseContext);
|
||||
if (!parseContext.enableNL2SQL()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -88,20 +85,25 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
if (parseContext.enableLLM()) {
|
||||
rewriteMultiTurn(parseContext, queryNLReq);
|
||||
addDynamicExemplars(parseContext, queryNLReq);
|
||||
|
||||
if (parseContext.needRuleParse()) {
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
for (MapModeEnum mode : MapModeEnum.values()) {
|
||||
queryNLReq.setMapModeEnum(mode);
|
||||
doParse(queryNLReq, parseResp);
|
||||
}
|
||||
}
|
||||
|
||||
doParse(queryNLReq, parseContext.getResponse());
|
||||
}
|
||||
|
||||
private void processFeedback(ParseContext parseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
for (MapModeEnum mode : MapModeEnum.values()) {
|
||||
queryNLReq.setMapModeEnum(mode);
|
||||
doParse(queryNLReq, parseResp);
|
||||
if (parseContext.needLLMParse() && !parseContext.needFeedback()) {
|
||||
SemanticParseInfo selectedParse = parseContext.getRequest().getSelectedParse();
|
||||
queryNLReq.setSelectedParseInfo(Objects.nonNull(selectedParse) ? selectedParse
|
||||
: parseContext.getResponse().getSelectedParses().get(0));
|
||||
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||
parseContext.getResponse().getSelectedParses().clear();
|
||||
rewriteMultiTurn(parseContext, queryNLReq);
|
||||
addDynamicExemplars(parseContext, queryNLReq);
|
||||
doParse(queryNLReq, parseContext.getResponse());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.pojo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Objects;
|
||||
@@ -20,14 +19,24 @@ public class ParseContext {
|
||||
}
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
return agent.containsDatasetTool();
|
||||
}
|
||||
|
||||
public boolean needFeedback() {
|
||||
return agent.enableFeedback() && Objects.isNull(request.getSelectedParse());
|
||||
return Objects.nonNull(agent) && agent.containsDatasetTool();
|
||||
}
|
||||
|
||||
public boolean enableLLM() {
|
||||
return !(needFeedback() || request.isDisableLLM());
|
||||
return !request.isDisableLLM();
|
||||
}
|
||||
|
||||
public boolean needFeedback() {
|
||||
return agent.enableFeedback() && (Objects.isNull(request.getSelectedParse())
|
||||
&& response.getSelectedParses().size() > 1);
|
||||
}
|
||||
|
||||
public boolean needRuleParse() {
|
||||
return Objects.isNull(request.getSelectedParse());
|
||||
}
|
||||
|
||||
public boolean needLLMParse() {
|
||||
return enableLLM() && (Objects.nonNull(request.getSelectedParse())
|
||||
|| !response.getSelectedParses().isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
|
||||
}
|
||||
});
|
||||
|
||||
Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0);
|
||||
sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 1 : 0);
|
||||
parseContext.getResponse().setSelectedParses(sortedParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,13 +119,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||
}
|
||||
|
||||
// no need for explicit user feedback if there is only one candidate parses
|
||||
if (parseContext.needFeedback() && parseContext.getResponse().getSelectedParses().size() == 1) {
|
||||
chatParseReq.setQueryId(parseContext.getResponse().getQueryId());
|
||||
chatParseReq.setSelectedParse(parseContext.getResponse().getSelectedParses().get(0));
|
||||
return parse(chatParseReq);
|
||||
}
|
||||
|
||||
return parseContext.getResponse();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -27,6 +29,7 @@ public class RuleSqlParser implements SemanticParser {
|
||||
return;
|
||||
}
|
||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||
List<SemanticQuery> candidateQueries = Lists.newArrayList();
|
||||
// iterate all schemaElementMatches to resolve query mode
|
||||
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
@@ -36,7 +39,10 @@ public class RuleSqlParser implements SemanticParser {
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
chatQueryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
candidateQueries.addAll(chatQueryContext.getCandidateQueries());
|
||||
chatQueryContext.getCandidateQueries().clear();
|
||||
}
|
||||
chatQueryContext.setCandidateQueries(candidateQueries);
|
||||
|
||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user