mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[improvement][chat]Sort parses in NL2SQLParser right after rule-based parsing.
This commit is contained in:
@@ -104,6 +104,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
.collect(Collectors.toList());
|
||||
parseContext.getResponse().getSelectedParses().addAll(sortedParses);
|
||||
}
|
||||
SemanticParseInfo.sort(parseContext.getResponse().getSelectedParses());
|
||||
}
|
||||
|
||||
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
|
||||
**/
|
||||
@Slf4j
|
||||
public class ParseInfoSortProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();
|
||||
selectedParses.sort(new SemanticParseInfo.SemanticParseComparator());
|
||||
// re-assign parseId
|
||||
for (int i = 0; i < selectedParses.size(); i++) {
|
||||
SemanticParseInfo parseInfo = selectedParses.get(i);
|
||||
parseInfo.setId(i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,13 +1,9 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum Text2SQLType {
|
||||
ONLY_RULE, ONLY_LLM, LLM_OR_RULE;
|
||||
|
||||
public boolean enableRule() {
|
||||
return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE);
|
||||
}
|
||||
ONLY_RULE, LLM_OR_RULE;
|
||||
|
||||
public boolean enableLLM() {
|
||||
return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE);
|
||||
return this.equals(LLM_OR_RULE);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +98,15 @@ public class SemanticParseInfo {
|
||||
}
|
||||
}
|
||||
|
||||
public static void sort(List<SemanticParseInfo> parses) {
|
||||
parses.sort(new SemanticParseComparator());
|
||||
// re-assign parseId
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
SemanticParseInfo parseInfo = parses.get(i);
|
||||
parseInfo.setId(i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||
@Override
|
||||
public int compare(SchemaElement o1, SchemaElement o2) {
|
||||
|
||||
@@ -34,15 +34,6 @@ public class LLMRequestService {
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
public boolean isSkip(ChatQueryContext queryCtx) {
|
||||
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
|
||||
log.info("LLM disabled, skip");
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public Long getDataSetId(ChatQueryContext queryCtx) {
|
||||
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
||||
return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds());
|
||||
|
||||
@@ -25,12 +25,12 @@ public class LLMSqlParser implements SemanticParser {
|
||||
@Override
|
||||
public void parse(ChatQueryContext queryCtx) {
|
||||
try {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
// 1.determine whether to skip this parser.
|
||||
if (requestService.isSkip(queryCtx)) {
|
||||
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
|
||||
return;
|
||||
}
|
||||
// 2.get dataSetId from queryCtx and chatCtx.
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
|
||||
@@ -24,8 +24,7 @@ public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
if (!chatQueryContext.getRequest().getText2SQLType().enableRule()
|
||||
|| !chatQueryContext.getCandidateQueries().isEmpty()) {
|
||||
if (!chatQueryContext.getCandidateQueries().isEmpty()) {
|
||||
return;
|
||||
}
|
||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||
|
||||
@@ -13,8 +13,8 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class QueryManager {
|
||||
|
||||
private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
|
||||
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
|
||||
private final static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
|
||||
private final static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static void register(SemanticQuery query) {
|
||||
if (query instanceof RuleSemanticQuery) {
|
||||
@@ -73,13 +73,6 @@ public class QueryManager {
|
||||
return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery;
|
||||
}
|
||||
|
||||
public static RuleSemanticQuery getRuleQuery(String queryMode) {
|
||||
if (queryMode == null) {
|
||||
return null;
|
||||
}
|
||||
return ruleQueryMap.get(queryMode);
|
||||
}
|
||||
|
||||
public static List<RuleSemanticQuery> getRuleQueries() {
|
||||
return new ArrayList<>(ruleQueryMap.values());
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ public class JdbcExecutor implements QueryExecutor {
|
||||
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
|
||||
queryResultWithColumns.setSql(sql);
|
||||
} catch (Exception e) {
|
||||
log.error("queryInternal error [{}]", e);
|
||||
log.error("queryInternal error [{}]", StringUtils.normalizeSpace(e.toString()));
|
||||
queryResultWithColumns.setErrorMsg(e.getMessage());
|
||||
}
|
||||
return queryResultWithColumns;
|
||||
|
||||
@@ -15,7 +15,9 @@ com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
|
||||
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor
|
||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
|
||||
|
||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||
|
||||
@@ -69,8 +69,7 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
|
||||
|
||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||
|
||||
@@ -157,7 +157,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
private static DatasetTool getDatasetTool() {
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(-1L));
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(1L));
|
||||
|
||||
return datasetTool;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user