mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
[improvement][chat]Modify core workflow of NL2SQLParser, always invoking rule-based parsers first.#1729
This commit is contained in:
@@ -11,6 +11,7 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -73,11 +74,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext) {
|
public void parse(ParseContext parseContext) {
|
||||||
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
if (!parseContext.enableNL2SQL()) {
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (parseContext.needFeedback()) {
|
|
||||||
processFeedback(parseContext);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,20 +85,25 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
|
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
|
||||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||||
}
|
}
|
||||||
if (parseContext.enableLLM()) {
|
|
||||||
rewriteMultiTurn(parseContext, queryNLReq);
|
if (parseContext.needRuleParse()) {
|
||||||
addDynamicExemplars(parseContext, queryNLReq);
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||||
|
ChatParseResp parseResp = parseContext.getResponse();
|
||||||
|
for (MapModeEnum mode : MapModeEnum.values()) {
|
||||||
|
queryNLReq.setMapModeEnum(mode);
|
||||||
|
doParse(queryNLReq, parseResp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
doParse(queryNLReq, parseContext.getResponse());
|
if (parseContext.needLLMParse() && !parseContext.needFeedback()) {
|
||||||
}
|
SemanticParseInfo selectedParse = parseContext.getRequest().getSelectedParse();
|
||||||
|
queryNLReq.setSelectedParseInfo(Objects.nonNull(selectedParse) ? selectedParse
|
||||||
private void processFeedback(ParseContext parseContext) {
|
: parseContext.getResponse().getSelectedParses().get(0));
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||||
ChatParseResp parseResp = parseContext.getResponse();
|
parseContext.getResponse().getSelectedParses().clear();
|
||||||
for (MapModeEnum mode : MapModeEnum.values()) {
|
rewriteMultiTurn(parseContext, queryNLReq);
|
||||||
queryNLReq.setMapModeEnum(mode);
|
addDynamicExemplars(parseContext, queryNLReq);
|
||||||
doParse(queryNLReq, parseResp);
|
doParse(queryNLReq, parseContext.getResponse());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.pojo;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -20,14 +19,24 @@ public class ParseContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableNL2SQL() {
|
public boolean enableNL2SQL() {
|
||||||
return agent.containsDatasetTool();
|
return Objects.nonNull(agent) && agent.containsDatasetTool();
|
||||||
}
|
|
||||||
|
|
||||||
public boolean needFeedback() {
|
|
||||||
return agent.enableFeedback() && Objects.isNull(request.getSelectedParse());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableLLM() {
|
public boolean enableLLM() {
|
||||||
return !(needFeedback() || request.isDisableLLM());
|
return !request.isDisableLLM();
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean needFeedback() {
|
||||||
|
return agent.enableFeedback() && (Objects.isNull(request.getSelectedParse())
|
||||||
|
&& response.getSelectedParses().size() > 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean needRuleParse() {
|
||||||
|
return Objects.isNull(request.getSelectedParse());
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean needLLMParse() {
|
||||||
|
return enableLLM() && (Objects.nonNull(request.getSelectedParse())
|
||||||
|
|| !response.getSelectedParses().isEmpty());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0);
|
sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 1 : 0);
|
||||||
parseContext.getResponse().setSelectedParses(sortedParseInfo);
|
parseContext.getResponse().setSelectedParses(sortedParseInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,13 +119,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||||
}
|
}
|
||||||
|
|
||||||
// no need for explicit user feedback if there is only one candidate parses
|
|
||||||
if (parseContext.needFeedback() && parseContext.getResponse().getSelectedParses().size() == 1) {
|
|
||||||
chatParseReq.setQueryId(parseContext.getResponse().getQueryId());
|
|
||||||
chatParseReq.setSelectedParse(parseContext.getResponse().getSelectedParses().get(0));
|
|
||||||
return parse(chatParseReq);
|
|
||||||
}
|
|
||||||
|
|
||||||
return parseContext.getResponse();
|
return parseContext.getResponse();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -27,6 +29,7 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||||
|
List<SemanticQuery> candidateQueries = Lists.newArrayList();
|
||||||
// iterate all schemaElementMatches to resolve query mode
|
// iterate all schemaElementMatches to resolve query mode
|
||||||
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
||||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||||
@@ -36,7 +39,10 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
query.fillParseInfo(chatQueryContext);
|
query.fillParseInfo(chatQueryContext);
|
||||||
chatQueryContext.getCandidateQueries().add(query);
|
chatQueryContext.getCandidateQueries().add(query);
|
||||||
}
|
}
|
||||||
|
candidateQueries.addAll(chatQueryContext.getCandidateQueries());
|
||||||
|
chatQueryContext.getCandidateQueries().clear();
|
||||||
}
|
}
|
||||||
|
chatQueryContext.setCandidateQueries(candidateQueries);
|
||||||
|
|
||||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user