From 85552fb73a08bb29a9c45c8d6fb344f6f65fb04e Mon Sep 17 00:00:00 2001 From: Jun Zhang Date: Thu, 10 Oct 2024 14:53:54 +0800 Subject: [PATCH] [improvement][chat]Merge `RuleTool` and `LLMTool`.#1766 (#1768) --- .../supersonic/chat/server/agent/Agent.java | 21 ++++++------------- .../chat/server/agent/AgentToolType.java | 7 +++---- .../{NL2SQLTool.java => DatasetTool.java} | 5 +++-- .../chat/server/agent/LLMParserTool.java | 11 ---------- .../chat/server/agent/RuleParserTool.java | 18 ---------------- .../chat/server/parser/NL2SQLParser.java | 4 ++-- .../chat/server/pojo/ParseContext.java | 9 +------- .../server/service/impl/AgentServiceImpl.java | 2 +- .../chat/server/util/QueryReqConverter.java | 16 +++----------- .../tencent/supersonic/demo/DuSQLDemo.java | 12 +++++------ .../tencent/supersonic/demo/S2ArtistDemo.java | 21 ++++++------------- .../tencent/supersonic/demo/S2VisitsDemo.java | 21 +++++++------------ .../supersonic/evaluation/Text2SQLEval.java | 10 ++++----- .../tencent/supersonic/util/DataUtils.java | 1 + 14 files changed, 44 insertions(+), 114 deletions(-) rename chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/{NL2SQLTool.java => DatasetTool.java} (63%) delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/LLMParserTool.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/RuleParserTool.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index c41fd3762..6d73b1d01 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -54,12 +54,12 @@ public class Agent extends RecordInfo { return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L); } - public List getParserTools(AgentToolType agentToolType) { + public List getParserTools(AgentToolType agentToolType) { List tools = this.getTools(agentToolType); if (CollectionUtils.isEmpty(tools)) { return Lists.newArrayList(); } - return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class)) + return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class)) .collect(Collectors.toList()); } @@ -67,17 +67,8 @@ public class Agent extends RecordInfo { return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN)); } - public boolean containsLLMTool() { - return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); - } - - public boolean containsRuleTool() { - return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE)); - } - - public boolean containsNL2SQLTool() { - return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)) - || !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE)); + public boolean containsDatasetTool() { + return !CollectionUtils.isEmpty(getParserTools(AgentToolType.DATASET)); } public boolean containsAnyTool() { @@ -102,11 +93,11 @@ public class Agent extends RecordInfo { } public Set getDataSetIds(AgentToolType agentToolType) { - List commonAgentTools = getParserTools(agentToolType); + List commonAgentTools = getParserTools(agentToolType); if (CollectionUtils.isEmpty(commonAgentTools)) { return new HashSet<>(); } - return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds) + return commonAgentTools.stream().map(DatasetTool::getDataSetIds) .filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream) .collect(Collectors.toSet()); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java index 45ea4355e..d34c1e8df 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java @@ -4,9 +4,9 @@ import java.util.HashMap; import java.util.Map; public enum AgentToolType { - NL2SQL_RULE("基于规则Text-to-SQL"), NL2SQL_LLM("基于大模型Text-to-SQL"), PLUGIN("第三方插件"); + DATASET("Text2SQL数据集"), PLUGIN("第三方插件"); - private String title; + private final String title; AgentToolType(String title) { this.title = title; @@ -14,8 +14,7 @@ public enum AgentToolType { public static Map getToolTypes() { Map map = new HashMap<>(); - map.put(NL2SQL_RULE, NL2SQL_RULE.title); - map.put(NL2SQL_LLM, NL2SQL_LLM.title); + map.put(DATASET, DATASET.title); map.put(PLUGIN, PLUGIN.title); return map; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/NL2SQLTool.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/DatasetTool.java similarity index 63% rename from chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/NL2SQLTool.java rename to chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/DatasetTool.java index 59f48bb38..6203599f7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/NL2SQLTool.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/DatasetTool.java @@ -9,7 +9,8 @@ import java.util.List; @Data @NoArgsConstructor @AllArgsConstructor -public class NL2SQLTool extends AgentTool { +public class DatasetTool extends AgentTool { - protected List dataSetIds; + private List dataSetIds; + private List exampleQuestions; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/LLMParserTool.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/LLMParserTool.java deleted file mode 100644 index 4921d0969..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/LLMParserTool.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.tencent.supersonic.chat.server.agent; - -import lombok.Data; - -import java.util.List; - -@Data -public class LLMParserTool extends NL2SQLTool { - - private List exampleQuestions; -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/RuleParserTool.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/RuleParserTool.java deleted file mode 100644 index 0596f7ad0..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/RuleParserTool.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.tencent.supersonic.chat.server.agent; - -import lombok.Data; -import org.apache.commons.collections.CollectionUtils; - -import java.util.List; - -@Data -public class RuleParserTool extends NL2SQLTool { - - private List queryModes; - - private List queryTypes; - - public boolean isContainsAllModel() { - return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L); - } -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 57faf19fc..9dbdd67a8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -85,7 +85,7 @@ public class NL2SQLParser implements ChatQueryParser { ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId()); - if (parseContext.enbaleLLM()) { + if (!parseContext.isDisableLLM()) { processMultiTurn(parseContext); } QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx); @@ -96,7 +96,7 @@ public class NL2SQLParser implements ChatQueryParser { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); } else { - if (parseContext.enbaleLLM()) { + if (!parseContext.isDisableLLM()) { parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(), text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(), parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig( diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index f3716be85..bdc5f93bc 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -21,13 +21,6 @@ public class ParseContext { if (agent == null) { return false; } - return agent.containsNL2SQLTool(); - } - - public boolean enbaleLLM() { - if (agent == null || disableLLM) { - return false; - } - return agent.containsLLMTool(); + return agent.containsDatasetTool(); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 7353d73fc..dda61aa1b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -105,7 +105,7 @@ public class AgentServiceImpl extends ServiceImpl implem } private synchronized void doExecuteAgentExamples(Agent agent) { - if (!agent.containsLLMTool() + if (!agent.containsDatasetTool() || !ModelConfigHelper.testConnection( ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL)) || CollectionUtils.isEmpty(agent.getExamples())) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 223943a1c..026fc14fa 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -26,22 +26,10 @@ public class QueryReqConverter { return queryNLReq; } - ChatModelConfig chatModelConfig = - ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL); - boolean hasLLMTool = agent.containsLLMTool(); - boolean hasRuleTool = agent.containsRuleTool(); - boolean hasLLMConfig = chatModelConfig != null; - if (parseContext.isDisableLLM()) { queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); - } else if (hasLLMTool && hasLLMConfig) { - queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM); - } else if (hasLLMTool && hasRuleTool) { + } else { queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); - } else if (hasLLMTool) { - queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM); - } else if (hasRuleTool) { - queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); } queryNLReq.setDataSetIds(agent.getDataSetIds()); @@ -49,6 +37,8 @@ public class QueryReqConverter { && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) { queryNLReq.setMapInfo(queryNLReq.getMapInfo()); } + ChatModelConfig chatModelConfig = + ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL); queryNLReq.setModelConfig(chatModelConfig); queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate()); if (chatCtx != null) { diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java index 308a0b749..6c1feb2c2 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java @@ -4,7 +4,7 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; -import com.tencent.supersonic.chat.server.agent.LLMParserTool; +import com.tencent.supersonic.chat.server.agent.DatasetTool; import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.ModelRela; @@ -335,11 +335,11 @@ public class DuSQLDemo extends S2BaseDemo { agent.setExamples(Lists.newArrayList()); ToolConfig toolConfig = new ToolConfig(); - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(4L)); - toolConfig.getTools().add(llmParserTool); + DatasetTool datasetTool = new DatasetTool(); + datasetTool.setId("1"); + datasetTool.setType(AgentToolType.DATASET); + datasetTool.setDataSetIds(Lists.newArrayList(4L)); + toolConfig.getTools().add(datasetTool); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); log.info("agent:{}", JsonUtil.toString(agent)); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java index d87beca7f..ed3d0c6fd 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java @@ -4,8 +4,7 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; -import com.tencent.supersonic.chat.server.agent.LLMParserTool; -import com.tencent.supersonic.chat.server.agent.RuleParserTool; +import com.tencent.supersonic.chat.server.agent.DatasetTool; import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.TypeEnums; @@ -169,19 +168,11 @@ public class S2ArtistDemo extends S2BaseDemo { agent.setEnableSearch(1); agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派")); ToolConfig toolConfig = new ToolConfig(); - RuleParserTool ruleQueryTool = new RuleParserTool(); - ruleQueryTool.setId("0"); - ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); - ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); - toolConfig.getTools().add(ruleQueryTool); - - if (demoEnableLlm) { - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId)); - toolConfig.getTools().add(llmParserTool); - } + DatasetTool datasetTool = new DatasetTool(); + datasetTool.setId("1"); + datasetTool.setType(AgentToolType.DATASET); + datasetTool.setDataSetIds(Lists.newArrayList(dataSetId)); + toolConfig.getTools().add(datasetTool); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agentService.createAgent(agent, defaultUser); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index e1b118e2d..b7d95f0e3 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -8,9 +8,8 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; -import com.tencent.supersonic.chat.server.agent.LLMParserTool; +import com.tencent.supersonic.chat.server.agent.DatasetTool; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; -import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; @@ -151,18 +150,12 @@ public class S2VisitsDemo extends S2BaseDemo { "过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长")); // configure tools ToolConfig toolConfig = new ToolConfig(); - RuleParserTool ruleQueryTool = new RuleParserTool(); - ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); - ruleQueryTool.setId("0"); - ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); - toolConfig.getTools().add(ruleQueryTool); - if (demoEnableLlm) { - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId)); - toolConfig.getTools().add(llmParserTool); - } + DatasetTool datasetTool = new DatasetTool(); + datasetTool.setId("1"); + datasetTool.setType(AgentToolType.DATASET); + datasetTool.setDataSetIds(Lists.newArrayList(dataSetId)); + toolConfig.getTools().add(datasetTool); + agent.setToolConfig(JSONObject.toJSONString(toolConfig)); // configure chat models Map chatModelConfig = Maps.newHashMap(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 33909ebc3..9b628d44f 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -148,11 +148,11 @@ public class Text2SQLEval extends BaseTest { return agent; } - private static LLMParserTool getLLMQueryTool() { - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(-1L)); + private static DatasetTool getLLMQueryTool() { + DatasetTool datasetTool = new DatasetTool(); + datasetTool.setType(AgentToolType.DATASET); + datasetTool.setDataSetIds(Lists.newArrayList(-1L)); - return llmParserTool; + return datasetTool; } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java index ca2b98485..a1a5441b1 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java @@ -42,6 +42,7 @@ public class DataUtils { chatParseReq.setQueryText(query); chatParseReq.setChatId(id); chatParseReq.setUser(user_test); + chatParseReq.setDisableLLM(true); return chatParseReq; }