[improvement][chat]Sort parses in NL2SQLParser right after rule-based parsing.

This commit is contained in:
jerryjzhang
2024-10-29 23:31:02 +08:00
parent 847505b293
commit 53ddc67262
12 changed files with 23 additions and 59 deletions

View File

@@ -34,15 +34,6 @@ public class LLMRequestService {
@Autowired
private ParserConfig parserConfig;
public boolean isSkip(ChatQueryContext queryCtx) {
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
log.info("LLM disabled, skip");
return true;
}
return false;
}
public Long getDataSetId(ChatQueryContext queryCtx) {
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds());

View File

@@ -25,12 +25,12 @@ public class LLMSqlParser implements SemanticParser {
@Override
public void parse(ChatQueryContext queryCtx) {
try {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
// 1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) {
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
return;
}
// 2.get dataSetId from queryCtx and chatCtx.
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
Long dataSetId = requestService.getDataSetId(queryCtx);
if (dataSetId == null) {
return;

View File

@@ -24,8 +24,7 @@ public class RuleSqlParser implements SemanticParser {
@Override
public void parse(ChatQueryContext chatQueryContext) {
if (!chatQueryContext.getRequest().getText2SQLType().enableRule()
|| !chatQueryContext.getCandidateQueries().isEmpty()) {
if (!chatQueryContext.getCandidateQueries().isEmpty()) {
return;
}
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();

View File

@@ -13,8 +13,8 @@ import java.util.concurrent.ConcurrentHashMap;
public class QueryManager {
private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
private final static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
private final static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
public static void register(SemanticQuery query) {
if (query instanceof RuleSemanticQuery) {
@@ -73,13 +73,6 @@ public class QueryManager {
return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery;
}
public static RuleSemanticQuery getRuleQuery(String queryMode) {
if (queryMode == null) {
return null;
}
return ruleQueryMap.get(queryMode);
}
public static List<RuleSemanticQuery> getRuleQueries() {
return new ArrayList<>(ruleQueryMap.values());
}