mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[improvement][chat]Sort parses in NL2SQLParser right after rule-based parsing.
This commit is contained in:
@@ -104,6 +104,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
parseContext.getResponse().getSelectedParses().addAll(sortedParses);
|
parseContext.getResponse().getSelectedParses().addAll(sortedParses);
|
||||||
}
|
}
|
||||||
|
SemanticParseInfo.sort(parseContext.getResponse().getSelectedParses());
|
||||||
}
|
}
|
||||||
|
|
||||||
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.parse;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
|
|
||||||
**/
|
|
||||||
@Slf4j
|
|
||||||
public class ParseInfoSortProcessor implements ParseResultProcessor {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void process(ParseContext parseContext) {
|
|
||||||
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();
|
|
||||||
selectedParses.sort(new SemanticParseInfo.SemanticParseComparator());
|
|
||||||
// re-assign parseId
|
|
||||||
for (int i = 0; i < selectedParses.size(); i++) {
|
|
||||||
SemanticParseInfo parseInfo = selectedParses.get(i);
|
|
||||||
parseInfo.setId(i + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,13 +1,9 @@
|
|||||||
package com.tencent.supersonic.common.pojo.enums;
|
package com.tencent.supersonic.common.pojo.enums;
|
||||||
|
|
||||||
public enum Text2SQLType {
|
public enum Text2SQLType {
|
||||||
ONLY_RULE, ONLY_LLM, LLM_OR_RULE;
|
ONLY_RULE, LLM_OR_RULE;
|
||||||
|
|
||||||
public boolean enableRule() {
|
|
||||||
return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE);
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean enableLLM() {
|
public boolean enableLLM() {
|
||||||
return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE);
|
return this.equals(LLM_OR_RULE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,6 +98,15 @@ public class SemanticParseInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void sort(List<SemanticParseInfo> parses) {
|
||||||
|
parses.sort(new SemanticParseComparator());
|
||||||
|
// re-assign parseId
|
||||||
|
for (int i = 0; i < parses.size(); i++) {
|
||||||
|
SemanticParseInfo parseInfo = parses.get(i);
|
||||||
|
parseInfo.setId(i + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||||
@Override
|
@Override
|
||||||
public int compare(SchemaElement o1, SchemaElement o2) {
|
public int compare(SchemaElement o1, SchemaElement o2) {
|
||||||
|
|||||||
@@ -34,15 +34,6 @@ public class LLMRequestService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ParserConfig parserConfig;
|
private ParserConfig parserConfig;
|
||||||
|
|
||||||
public boolean isSkip(ChatQueryContext queryCtx) {
|
|
||||||
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
|
|
||||||
log.info("LLM disabled, skip");
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Long getDataSetId(ChatQueryContext queryCtx) {
|
public Long getDataSetId(ChatQueryContext queryCtx) {
|
||||||
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
||||||
return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds());
|
return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds());
|
||||||
|
|||||||
@@ -25,12 +25,12 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
@Override
|
@Override
|
||||||
public void parse(ChatQueryContext queryCtx) {
|
public void parse(ChatQueryContext queryCtx) {
|
||||||
try {
|
try {
|
||||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
|
||||||
// 1.determine whether to skip this parser.
|
// 1.determine whether to skip this parser.
|
||||||
if (requestService.isSkip(queryCtx)) {
|
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// 2.get dataSetId from queryCtx and chatCtx.
|
// 2.get dataSetId from queryCtx and chatCtx.
|
||||||
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||||
if (dataSetId == null) {
|
if (dataSetId == null) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatQueryContext chatQueryContext) {
|
public void parse(ChatQueryContext chatQueryContext) {
|
||||||
if (!chatQueryContext.getRequest().getText2SQLType().enableRule()
|
if (!chatQueryContext.getCandidateQueries().isEmpty()) {
|
||||||
|| !chatQueryContext.getCandidateQueries().isEmpty()) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||||||
|
|
||||||
public class QueryManager {
|
public class QueryManager {
|
||||||
|
|
||||||
private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
|
private final static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
|
||||||
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
|
private final static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
public static void register(SemanticQuery query) {
|
public static void register(SemanticQuery query) {
|
||||||
if (query instanceof RuleSemanticQuery) {
|
if (query instanceof RuleSemanticQuery) {
|
||||||
@@ -73,13 +73,6 @@ public class QueryManager {
|
|||||||
return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery;
|
return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static RuleSemanticQuery getRuleQuery(String queryMode) {
|
|
||||||
if (queryMode == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return ruleQueryMap.get(queryMode);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static List<RuleSemanticQuery> getRuleQueries() {
|
public static List<RuleSemanticQuery> getRuleQueries() {
|
||||||
return new ArrayList<>(ruleQueryMap.values());
|
return new ArrayList<>(ruleQueryMap.values());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ public class JdbcExecutor implements QueryExecutor {
|
|||||||
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
|
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
|
||||||
queryResultWithColumns.setSql(sql);
|
queryResultWithColumns.setSql(sql);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("queryInternal error [{}]", e);
|
log.error("queryInternal error [{}]", StringUtils.normalizeSpace(e.toString()));
|
||||||
queryResultWithColumns.setErrorMsg(e.getMessage());
|
queryResultWithColumns.setErrorMsg(e.getMessage());
|
||||||
}
|
}
|
||||||
return queryResultWithColumns;
|
return queryResultWithColumns;
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
|
|||||||
|
|
||||||
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor
|
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
||||||
|
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
||||||
|
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
|
||||||
|
|
||||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||||
|
|||||||
@@ -69,8 +69,7 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
|||||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
|
||||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor
|
|
||||||
|
|
||||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
private static DatasetTool getDatasetTool() {
|
private static DatasetTool getDatasetTool() {
|
||||||
DatasetTool datasetTool = new DatasetTool();
|
DatasetTool datasetTool = new DatasetTool();
|
||||||
datasetTool.setType(AgentToolType.DATASET);
|
datasetTool.setType(AgentToolType.DATASET);
|
||||||
datasetTool.setDataSetIds(Lists.newArrayList(-1L));
|
datasetTool.setDataSetIds(Lists.newArrayList(1L));
|
||||||
|
|
||||||
return datasetTool;
|
return datasetTool;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user