[improvement][chat]Sort parses in NL2SQLParser right after rule-based parsing.

This commit is contained in:
jerryjzhang
2024-10-29 23:31:02 +08:00
parent 847505b293
commit 53ddc67262
12 changed files with 23 additions and 59 deletions

View File

@@ -104,6 +104,7 @@ public class NL2SQLParser implements ChatQueryParser {
.collect(Collectors.toList()); .collect(Collectors.toList());
parseContext.getResponse().getSelectedParses().addAll(sortedParses); parseContext.getResponse().getSelectedParses().addAll(sortedParses);
} }
SemanticParseInfo.sort(parseContext.getResponse().getSelectedParses());
} }
// next go with llm-based parsers unless LLM is disabled or use feedback is needed. // next go with llm-based parsers unless LLM is disabled or use feedback is needed.

View File

@@ -1,26 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
/**
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
**/
@Slf4j
public class ParseInfoSortProcessor implements ParseResultProcessor {
@Override
public void process(ParseContext parseContext) {
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();
selectedParses.sort(new SemanticParseInfo.SemanticParseComparator());
// re-assign parseId
for (int i = 0; i < selectedParses.size(); i++) {
SemanticParseInfo parseInfo = selectedParses.get(i);
parseInfo.setId(i + 1);
}
}
}

View File

@@ -1,13 +1,9 @@
package com.tencent.supersonic.common.pojo.enums; package com.tencent.supersonic.common.pojo.enums;
public enum Text2SQLType { public enum Text2SQLType {
ONLY_RULE, ONLY_LLM, LLM_OR_RULE; ONLY_RULE, LLM_OR_RULE;
public boolean enableRule() {
return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE);
}
public boolean enableLLM() { public boolean enableLLM() {
return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE); return this.equals(LLM_OR_RULE);
} }
} }

View File

@@ -98,6 +98,15 @@ public class SemanticParseInfo {
} }
} }
public static void sort(List<SemanticParseInfo> parses) {
parses.sort(new SemanticParseComparator());
// re-assign parseId
for (int i = 0; i < parses.size(); i++) {
SemanticParseInfo parseInfo = parses.get(i);
parseInfo.setId(i + 1);
}
}
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> { private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override @Override
public int compare(SchemaElement o1, SchemaElement o2) { public int compare(SchemaElement o1, SchemaElement o2) {

View File

@@ -34,15 +34,6 @@ public class LLMRequestService {
@Autowired @Autowired
private ParserConfig parserConfig; private ParserConfig parserConfig;
public boolean isSkip(ChatQueryContext queryCtx) {
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
log.info("LLM disabled, skip");
return true;
}
return false;
}
public Long getDataSetId(ChatQueryContext queryCtx) { public Long getDataSetId(ChatQueryContext queryCtx) {
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver(); DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds()); return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds());

View File

@@ -25,12 +25,12 @@ public class LLMSqlParser implements SemanticParser {
@Override @Override
public void parse(ChatQueryContext queryCtx) { public void parse(ChatQueryContext queryCtx) {
try { try {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
// 1.determine whether to skip this parser. // 1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) { if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
return; return;
} }
// 2.get dataSetId from queryCtx and chatCtx. // 2.get dataSetId from queryCtx and chatCtx.
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
Long dataSetId = requestService.getDataSetId(queryCtx); Long dataSetId = requestService.getDataSetId(queryCtx);
if (dataSetId == null) { if (dataSetId == null) {
return; return;

View File

@@ -24,8 +24,7 @@ public class RuleSqlParser implements SemanticParser {
@Override @Override
public void parse(ChatQueryContext chatQueryContext) { public void parse(ChatQueryContext chatQueryContext) {
if (!chatQueryContext.getRequest().getText2SQLType().enableRule() if (!chatQueryContext.getCandidateQueries().isEmpty()) {
|| !chatQueryContext.getCandidateQueries().isEmpty()) {
return; return;
} }
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();

View File

@@ -13,8 +13,8 @@ import java.util.concurrent.ConcurrentHashMap;
public class QueryManager { public class QueryManager {
private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>(); private final static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>(); private final static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
public static void register(SemanticQuery query) { public static void register(SemanticQuery query) {
if (query instanceof RuleSemanticQuery) { if (query instanceof RuleSemanticQuery) {
@@ -73,13 +73,6 @@ public class QueryManager {
return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery; return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery;
} }
public static RuleSemanticQuery getRuleQuery(String queryMode) {
if (queryMode == null) {
return null;
}
return ruleQueryMap.get(queryMode);
}
public static List<RuleSemanticQuery> getRuleQueries() { public static List<RuleSemanticQuery> getRuleQueries() {
return new ArrayList<>(ruleQueryMap.values()); return new ArrayList<>(ruleQueryMap.values());
} }

View File

@@ -45,7 +45,7 @@ public class JdbcExecutor implements QueryExecutor {
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns); sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
queryResultWithColumns.setSql(sql); queryResultWithColumns.setSql(sql);
} catch (Exception e) { } catch (Exception e) {
log.error("queryInternal error [{}]", e); log.error("queryInternal error [{}]", StringUtils.normalizeSpace(e.toString()));
queryResultWithColumns.setErrorMsg(e.getMessage()); queryResultWithColumns.setErrorMsg(e.getMessage());
} }
return queryResultWithColumns; return queryResultWithColumns;

View File

@@ -15,7 +15,9 @@ com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\ com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\

View File

@@ -69,8 +69,7 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\ com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\ com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\ com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\ com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\

View File

@@ -157,7 +157,7 @@ public class Text2SQLEval extends BaseTest {
private static DatasetTool getDatasetTool() { private static DatasetTool getDatasetTool() {
DatasetTool datasetTool = new DatasetTool(); DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET); datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(-1L)); datasetTool.setDataSetIds(Lists.newArrayList(1L));
return datasetTool; return datasetTool;
} }