From 5f6e9ae1947d854850fd12d668db8ffc544b5e6a Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Sun, 7 Apr 2024 14:33:17 +0800 Subject: [PATCH] (improvement)(Headless) Add Text2SQLType to control whether rules and large models are passed (#891) Co-authored-by: jolunoluo --- .../supersonic/chat/server/agent/Agent.java | 4 ++++ .../chat/server/util/QueryReqConverter.java | 9 +++++++-- .../common/pojo/enums/Text2SQLType.java | 15 +++++++++++++++ .../headless/api/pojo/request/QueryReq.java | 3 ++- .../core/chat/parser/llm/LLMRequestService.java | 14 +++++++------- .../core/chat/parser/rule/RuleSqlParser.java | 6 +++++- .../headless/core/pojo/QueryContext.java | 12 +++++++----- .../server/service/impl/ChatQueryServiceImpl.java | 2 +- 8 files changed, 48 insertions(+), 17 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 267607363..f6f5e2a3c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -70,6 +70,10 @@ public class Agent extends RecordInfo { return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); } + public boolean containsRuleTool() { + return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE)); + } + public boolean containsNL2SQLTool() { return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)) || !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE)); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 60c506aba..c1328e3da 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.util; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import org.apache.commons.collections.MapUtils; @@ -17,8 +18,12 @@ public class QueryReqConverter { if (agent == null) { return queryReq; } - if (agent.containsLLMParserTool()) { - queryReq.setEnableLLM(true); + if (agent.containsLLMParserTool() && agent.containsRuleTool()) { + queryReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); + } else if (agent.containsLLMParserTool()) { + queryReq.setText2SQLType(Text2SQLType.ONLY_LLM); + } else if (agent.containsRuleTool()) { + queryReq.setText2SQLType(Text2SQLType.ONLY_RULE); } queryReq.setDataSetIds(agent.getDataSetIds()); if (Objects.nonNull(queryReq.getMapInfo()) diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java new file mode 100644 index 000000000..b40f06b73 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java @@ -0,0 +1,15 @@ +package com.tencent.supersonic.common.pojo.enums; + +public enum Text2SQLType { + + ONLY_RULE, ONLY_LLM, RULE_AND_LLM; + + public boolean enableRule() { + return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM); + } + + public boolean enableLLM() { + return this.equals(ONLY_LLM) || this.equals(RULE_AND_LLM); + } + +} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java index 15c39df12..dea2fc0cf 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import lombok.Data; @@ -15,6 +16,6 @@ public class QueryReq { private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; - private boolean enableLLM; + private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private SchemaMapInfo mapInfo = new SchemaMapInfo(); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index 5410870c9..14141ef87 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -16,6 +16,12 @@ import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.utils.ComponentFactory; import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.Comparator; import java.util.HashSet; @@ -24,12 +30,6 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; @Slf4j @Service @@ -41,7 +41,7 @@ public class LLMRequestService { private OptimizationConfig optimizationConfig; public boolean isSkip(QueryContext queryCtx) { - if (!queryCtx.isEnableLLM()) { + if (!queryCtx.getText2SQLType().enableLLM()) { log.info("not enable llm, skip"); return true; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java index 92de7be6e..b59fdf00c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java @@ -6,9 +6,9 @@ import com.tencent.supersonic.headless.core.chat.parser.SemanticParser; import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.core.pojo.ChatContext; import com.tencent.supersonic.headless.core.pojo.QueryContext; +import lombok.extern.slf4j.Slf4j; import java.util.Arrays; import java.util.List; -import lombok.extern.slf4j.Slf4j; /** * RuleSqlParser resolves a specific SemanticQuery according to co-appearance @@ -25,6 +25,10 @@ public class RuleSqlParser implements SemanticParser { @Override public void parse(QueryContext queryContext, ChatContext chatContext) { + if (!queryContext.getText2SQLType().enableRule()) { + log.info("not enable rule, skip"); + return; + } SchemaMapInfo mapInfo = queryContext.getMapInfo(); // iterate all schemaElementMatches to resolve query mode for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java index d7f9dec1e..c151bfd77 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java @@ -2,22 +2,24 @@ package com.tencent.supersonic.headless.core.pojo; import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; import com.tencent.supersonic.headless.core.config.OptimizationConfig; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; @Data @Builder @@ -31,7 +33,7 @@ public class QueryContext { private Map> modelIdToDataSetIds; private User user; private boolean saveAnswer; - private boolean enableLLM; + private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private QueryFilters queryFilters; private List candidateQueries = new ArrayList<>(); private SchemaMapInfo mapInfo = new SchemaMapInfo(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index a461bfcf6..64ef10414 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -177,7 +177,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { .candidateQueries(new ArrayList<>()) .mapInfo(new SchemaMapInfo()) .modelIdToDataSetIds(modelIdToDataSetIds) - .enableLLM(queryReq.isEnableLLM()) + .text2SQLType(queryReq.getText2SQLType()) .build(); BeanUtils.copyProperties(queryReq, queryCtx); return queryCtx;