diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 378a4e994..8f84de6a2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -104,6 +104,7 @@ public class NL2SQLParser implements ChatQueryParser { .collect(Collectors.toList()); parseContext.getResponse().getSelectedParses().addAll(sortedParses); } + SemanticParseInfo.sort(parseContext.getResponse().getSelectedParses()); } // next go with llm-based parsers unless LLM is disabled or use feedback is needed. diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java deleted file mode 100644 index 8628a995e..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java +++ /dev/null @@ -1,26 +0,0 @@ -package com.tencent.supersonic.chat.server.processor.parse; - -import com.tencent.supersonic.chat.server.pojo.ParseContext; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import lombok.extern.slf4j.Slf4j; - -import java.util.*; - -/** - * ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \ - **/ -@Slf4j -public class ParseInfoSortProcessor implements ParseResultProcessor { - - @Override - public void process(ParseContext parseContext) { - List selectedParses = parseContext.getResponse().getSelectedParses(); - selectedParses.sort(new SemanticParseInfo.SemanticParseComparator()); - // re-assign parseId - for (int i = 0; i < selectedParses.size(); i++) { - SemanticParseInfo parseInfo = selectedParses.get(i); - parseInfo.setId(i + 1); - } - } - -} diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java index f38ed6efb..150f4f1a3 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java @@ -1,13 +1,9 @@ package com.tencent.supersonic.common.pojo.enums; public enum Text2SQLType { - ONLY_RULE, ONLY_LLM, LLM_OR_RULE; - - public boolean enableRule() { - return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE); - } + ONLY_RULE, LLM_OR_RULE; public boolean enableLLM() { - return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE); + return this.equals(LLM_OR_RULE); } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index 20fa1d489..8645f7776 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -98,6 +98,15 @@ public class SemanticParseInfo { } } + public static void sort(List parses) { + parses.sort(new SemanticParseComparator()); + // re-assign parseId + for (int i = 0; i < parses.size(); i++) { + SemanticParseInfo parseInfo = parses.get(i); + parseInfo.setId(i + 1); + } + } + private static class SchemaNameLengthComparator implements Comparator { @Override public int compare(SchemaElement o1, SchemaElement o2) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index dc6f87405..39cebf538 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -34,15 +34,6 @@ public class LLMRequestService { @Autowired private ParserConfig parserConfig; - public boolean isSkip(ChatQueryContext queryCtx) { - if (!queryCtx.getRequest().getText2SQLType().enableLLM()) { - log.info("LLM disabled, skip"); - return true; - } - - return false; - } - public Long getDataSetId(ChatQueryContext queryCtx) { DataSetResolver dataSetResolver = ComponentFactory.getModelResolver(); return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java index 375292419..8bcc95ac3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java @@ -25,12 +25,12 @@ public class LLMSqlParser implements SemanticParser { @Override public void parse(ChatQueryContext queryCtx) { try { - LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); // 1.determine whether to skip this parser. - if (requestService.isSkip(queryCtx)) { + if (!queryCtx.getRequest().getText2SQLType().enableLLM()) { return; } // 2.get dataSetId from queryCtx and chatCtx. + LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); Long dataSetId = requestService.getDataSetId(queryCtx); if (dataSetId == null) { return; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index c8be76c2b..34a93975c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -24,8 +24,7 @@ public class RuleSqlParser implements SemanticParser { @Override public void parse(ChatQueryContext chatQueryContext) { - if (!chatQueryContext.getRequest().getText2SQLType().enableRule() - || !chatQueryContext.getCandidateQueries().isEmpty()) { + if (!chatQueryContext.getCandidateQueries().isEmpty()) { return; } SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java index 752239e46..71c9b64fa 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java @@ -13,8 +13,8 @@ import java.util.concurrent.ConcurrentHashMap; public class QueryManager { - private static Map ruleQueryMap = new ConcurrentHashMap<>(); - private static Map llmQueryMap = new ConcurrentHashMap<>(); + private final static Map ruleQueryMap = new ConcurrentHashMap<>(); + private final static Map llmQueryMap = new ConcurrentHashMap<>(); public static void register(SemanticQuery query) { if (query instanceof RuleSemanticQuery) { @@ -73,13 +73,6 @@ public class QueryManager { return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery; } - public static RuleSemanticQuery getRuleQuery(String queryMode) { - if (queryMode == null) { - return null; - } - return ruleQueryMap.get(queryMode); - } - public static List getRuleQueries() { return new ArrayList<>(ruleQueryMap.values()); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java index be4791287..6c4e8ed85 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java @@ -45,7 +45,7 @@ public class JdbcExecutor implements QueryExecutor { sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns); queryResultWithColumns.setSql(sql); } catch (Exception e) { - log.error("queryInternal error [{}]", e); + log.error("queryInternal error [{}]", StringUtils.normalizeSpace(e.toString())); queryResultWithColumns.setErrorMsg(e.getMessage()); } return queryResultWithColumns; diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index 4b0a855db..77c3b9eec 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -15,7 +15,9 @@ com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\ com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\ - com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor + com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\ + com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\ + com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\ com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\ diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 35e13a85a..2a5caf958 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -69,8 +69,7 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\ com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\ com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\ - com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\ - com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor + com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\ com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\ diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index de395526c..e158c7358 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -157,7 +157,7 @@ public class Text2SQLEval extends BaseTest { private static DatasetTool getDatasetTool() { DatasetTool datasetTool = new DatasetTool(); datasetTool.setType(AgentToolType.DATASET); - datasetTool.setDataSetIds(Lists.newArrayList(-1L)); + datasetTool.setDataSetIds(Lists.newArrayList(1L)); return datasetTool; }