From 57f7d0c67d83a1092f932669caa8f63cf2ed0953 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Wed, 29 Nov 2023 16:34:52 +0800 Subject: [PATCH] [improvement][chat]Restructure Agent&Tool package --- .../tencent/supersonic/chat/agent/Agent.java | 1 - .../supersonic/chat/agent/AgentConfig.java | 1 - .../chat/agent/{tool => }/AgentTool.java | 2 +- .../supersonic/chat/agent/AgentToolType.java | 8 ++++++++ ...erpretTool.java => DataAnalyticsTool.java} | 4 ++-- .../supersonic/chat/agent/LLMParserTool.java | 12 +++++++++++ .../CommonAgentTool.java => NL2SQLTool.java} | 4 ++-- .../chat/agent/{tool => }/PluginTool.java | 2 +- ...RuleQueryTool.java => RuleParserTool.java} | 4 ++-- .../chat/agent/tool/AgentToolType.java | 8 -------- .../chat/agent/tool/LLMParserTool.java | 12 ----------- .../llm/interpret/MetricInterpretParser.java | 20 +++++++++---------- .../parser/llm/s2sql/LLMRequestService.java | 14 ++++++------- .../parser/llm/s2sql/LLMResponseService.java | 4 ++-- .../chat/parser/llm/s2sql/LLMS2SQLParser.java | 4 ++-- .../chat/parser/llm/s2sql/ParseResult.java | 4 ++-- .../chat/parser/rule/AgentCheckParser.java | 14 ++++++------- .../supersonic/chat/plugin/PluginManager.java | 4 ++-- .../supersonic/chat/service/AgentService.java | 6 +++--- .../chat/service/impl/AgentServiceImpl.java | 12 +++++------ .../com/tencent/supersonic/ConfigureDemo.java | 20 +++++++++---------- .../tencent/supersonic/util/DataUtils.java | 20 +++++++++---------- 22 files changed, 89 insertions(+), 91 deletions(-) rename chat/core/src/main/java/com/tencent/supersonic/chat/agent/{tool => }/AgentTool.java (83%) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentToolType.java rename chat/core/src/main/java/com/tencent/supersonic/chat/agent/{tool/MetricInterpretTool.java => DataAnalyticsTool.java} (66%) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/LLMParserTool.java rename chat/core/src/main/java/com/tencent/supersonic/chat/agent/{tool/CommonAgentTool.java => NL2SQLTool.java} (66%) rename chat/core/src/main/java/com/tencent/supersonic/chat/agent/{tool => }/PluginTool.java (73%) rename chat/core/src/main/java/com/tencent/supersonic/chat/agent/{tool/RuleQueryTool.java => RuleParserTool.java} (75%) delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/LLMParserTool.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java index a747b53b8..0db60b10c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java @@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.agent; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.common.pojo.RecordInfo; import java.util.Objects; import lombok.Data; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java index 9f675cead..233f62e14 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.chat.agent; import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.agent.tool.AgentTool; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentTool.java similarity index 83% rename from chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentTool.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentTool.java index ff5c59029..102362bc3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentTool.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentTool.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.agent.tool; +package com.tencent.supersonic.chat.agent; import lombok.AllArgsConstructor; import lombok.Data; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentToolType.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentToolType.java new file mode 100644 index 000000000..0fd1a95c4 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentToolType.java @@ -0,0 +1,8 @@ +package com.tencent.supersonic.chat.agent; + +public enum AgentToolType { + NL2SQL_RULE, + NL2SQL_LLM, + PLUGIN, + ANALYTICS +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/MetricInterpretTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/DataAnalyticsTool.java similarity index 66% rename from chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/MetricInterpretTool.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/agent/DataAnalyticsTool.java index 4c71f8a87..b86ee44d9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/MetricInterpretTool.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/DataAnalyticsTool.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.agent.tool; +package com.tencent.supersonic.chat.agent; import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption; import lombok.Data; @@ -7,7 +7,7 @@ import java.util.List; @Data -public class MetricInterpretTool extends AgentTool { +public class DataAnalyticsTool extends AgentTool { private Long modelId; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/LLMParserTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/LLMParserTool.java new file mode 100644 index 000000000..7d498f3ff --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/LLMParserTool.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.chat.agent; + +import lombok.Data; + +import java.util.List; + +@Data +public class LLMParserTool extends NL2SQLTool { + + private List exampleQuestions; + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/CommonAgentTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/NL2SQLTool.java similarity index 66% rename from chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/CommonAgentTool.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/agent/NL2SQLTool.java index 7c90e9be0..598b1a0e4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/CommonAgentTool.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/NL2SQLTool.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.agent.tool; +package com.tencent.supersonic.chat.agent; import java.util.List; @@ -9,7 +9,7 @@ import lombok.NoArgsConstructor; @Data @NoArgsConstructor @AllArgsConstructor -public class CommonAgentTool extends AgentTool { +public class NL2SQLTool extends AgentTool { protected List modelIds; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/PluginTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/PluginTool.java similarity index 73% rename from chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/PluginTool.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/agent/PluginTool.java index 8ccb2671e..fc9cf0a76 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/PluginTool.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/PluginTool.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.agent.tool; +package com.tencent.supersonic.chat.agent; import lombok.Data; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/RuleParserTool.java similarity index 75% rename from chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/agent/RuleParserTool.java index 03c4e4ef5..53eced6ed 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/RuleParserTool.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.agent.tool; +package com.tencent.supersonic.chat.agent; import lombok.Data; @@ -7,7 +7,7 @@ import org.apache.commons.collections.CollectionUtils; import java.util.List; @Data -public class RuleQueryTool extends CommonAgentTool { +public class RuleParserTool extends NL2SQLTool { private List queryModes; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java deleted file mode 100644 index 305bb4f96..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java +++ /dev/null @@ -1,8 +0,0 @@ -package com.tencent.supersonic.chat.agent.tool; - -public enum AgentToolType { - RULE, - LLM_S2SQL, - PLUGIN, - INTERPRET -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/LLMParserTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/LLMParserTool.java deleted file mode 100644 index dd57a1833..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/LLMParserTool.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.tencent.supersonic.chat.agent.tool; - -import lombok.Data; - -import java.util.List; - -@Data -public class LLMParserTool extends CommonAgentTool { - - private List exampleQuestions; - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java index fea2c1d28..c406b6dfb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java @@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.parser.llm.interpret; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Sets; import com.tencent.supersonic.chat.agent.Agent; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.DataAnalyticsTool; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; @@ -42,7 +42,7 @@ public class MetricInterpretParser implements SemanticParser { log.info("skip MetricInterpretParser"); return; } - Map metricInterpretToolMap = + Map metricInterpretToolMap = getMetricInterpretTools(queryContext.getRequest().getAgentId()); log.info("metric interpret tool : {}", metricInterpretToolMap); if (CollectionUtils.isEmpty(metricInterpretToolMap)) { @@ -50,7 +50,7 @@ public class MetricInterpretParser implements SemanticParser { } Map> elementMatches = queryContext.getMapInfo().getModelElementMatches(); for (Long modelId : elementMatches.keySet()) { - MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId); + DataAnalyticsTool metricInterpretTool = metricInterpretToolMap.get(modelId); if (metricInterpretTool == null) { continue; } @@ -86,22 +86,22 @@ public class MetricInterpretParser implements SemanticParser { .collect(Collectors.toSet()); } - private Map getMetricInterpretTools(Integer agentId) { + private Map getMetricInterpretTools(Integer agentId) { AgentService agentService = ContextUtils.getBean(AgentService.class); Agent agent = agentService.getAgent(agentId); if (agent == null) { return new HashMap<>(); } - List tools = agent.getTools(AgentToolType.INTERPRET); + List tools = agent.getTools(AgentToolType.ANALYTICS); if (CollectionUtils.isEmpty(tools)) { return new HashMap<>(); } - List metricInterpretTools = tools.stream().map(tool -> - JSONObject.parseObject(tool, MetricInterpretTool.class)) + List metricInterpretTools = tools.stream().map(tool -> + JSONObject.parseObject(tool, DataAnalyticsTool.class)) .filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions())) .collect(Collectors.toList()); - Map metricInterpretToolMap = new HashMap<>(); - for (MetricInterpretTool metricInterpretTool : metricInterpretTools) { + Map metricInterpretToolMap = new HashMap<>(); + for (DataAnalyticsTool metricInterpretTool : metricInterpretTools) { metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(), metricInterpretTool); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java index 78fe5853a..9346d61dc 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.NL2SQLTool; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; @@ -74,7 +74,7 @@ public class LLMRequestService { } public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) { - Set distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL); + Set distinctModelIds = agentService.getModelIds(agentId, AgentToolType.NL2SQL_LLM); if (agentService.containsAllModel(distinctModelIds)) { distinctModelIds = new HashSet<>(); } @@ -84,10 +84,10 @@ public class LLMRequestService { return ModelCluster.build(modelCluster); } - public CommonAgentTool getParserTool(QueryReq request, Set modelIdSet) { - List commonAgentTools = agentService.getParserTools(request.getAgentId(), - AgentToolType.LLM_S2SQL); - Optional llmParserTool = commonAgentTools.stream() + public NL2SQLTool getParserTool(QueryReq request, Set modelIdSet) { + List commonAgentTools = agentService.getParserTools(request.getAgentId(), + AgentToolType.NL2SQL_LLM); + Optional llmParserTool = commonAgentTools.stream() .filter(tool -> { List modelIds = tool.getModelIds(); if (agentService.containsAllModel(new HashSet<>(modelIds))) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java index 601d42b59..1843332e1 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.chat.parser.llm.s2sql; -import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.agent.NL2SQLTool; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.query.QueryManager; @@ -28,7 +28,7 @@ public class LLMResponseService { LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE); SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); parseInfo.setModel(parseResult.getModelCluster()); - CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool(); + NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool(); parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo() .getMatchedElements(parseInfo.getModelClusterKey())); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java index db16f3c3c..5b2db2c00 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.chat.parser.llm.s2sql; -import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.agent.NL2SQLTool; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; @@ -39,7 +39,7 @@ public class LLMS2SQLParser implements SemanticParser { return; } //3.get agent tool and determine whether to skip this parser. - CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds()); + NL2SQLTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds()); if (Objects.isNull(commonAgentTool)) { log.info("no tool in this agent, skip {}", LLMS2SQLParser.class); return; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java index 0f50db29e..d7aec0ec9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.chat.parser.llm.s2sql; -import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.agent.NL2SQLTool; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; @@ -27,7 +27,7 @@ public class ParseResult { private QueryReq request; - private CommonAgentTool commonAgentTool; + private NL2SQLTool commonAgentTool; private List linkingValues; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java index 29bff4be0..3daa975f2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java @@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.parser.rule; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.agent.Agent; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.RuleParserTool; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.ChatContext; @@ -35,7 +35,7 @@ public class AgentCheckParser implements SemanticParser { if (agent == null) { return; } - List queryTools = getRuleTools(agentId); + List queryTools = getRuleTools(agentId); if (CollectionUtils.isEmpty(queryTools)) { queries.clear(); return; @@ -43,7 +43,7 @@ public class AgentCheckParser implements SemanticParser { log.info("queries resolved:{} {}", agent.getName(), queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); queries.removeIf(query -> { - for (RuleQueryTool tool : queryTools) { + for (RuleParserTool tool : queryTools) { if (CollectionUtils.isNotEmpty(tool.getQueryModes()) && !tool.getQueryModes().contains(query.getQueryMode())) { return true; @@ -73,17 +73,17 @@ public class AgentCheckParser implements SemanticParser { queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); } - private static List getRuleTools(Integer agentId) { + private static List getRuleTools(Integer agentId) { AgentService agentService = ContextUtils.getBean(AgentService.class); Agent agent = agentService.getAgent(agentId); if (agent == null) { return Lists.newArrayList(); } - List tools = agent.getTools(AgentToolType.RULE); + List tools = agent.getTools(AgentToolType.NL2SQL_RULE); if (CollectionUtils.isEmpty(tools)) { return Lists.newArrayList(); } - return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class)) + return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleParserTool.class)) .collect(Collectors.toList()); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java index 8fcb9da4a..f029fae66 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java @@ -9,8 +9,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.agent.Agent; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.PluginTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.PluginTool; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp; import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java index 5b46dcb8f..6ca47c236 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java @@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.service; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.agent.Agent; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.NL2SQLTool; import java.util.List; import java.util.Set; @@ -19,7 +19,7 @@ public interface AgentService { void deleteAgent(Integer id); - List getParserTools(Integer agentId, AgentToolType agentToolType); + List getParserTools(Integer agentId, AgentToolType agentToolType); Set getModelIds(Integer agentId, AgentToolType agentToolType); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java index d3e3a649f..ed7a83e09 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java @@ -4,8 +4,8 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.agent.Agent; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.NL2SQLTool; import com.tencent.supersonic.chat.persistence.dataobject.AgentDO; import com.tencent.supersonic.chat.persistence.repository.AgentRepository; import com.tencent.supersonic.chat.service.AgentService; @@ -87,7 +87,7 @@ public class AgentServiceImpl implements AgentService { return agentDO; } - public List getParserTools(Integer agentId, AgentToolType agentToolType) { + public List getParserTools(Integer agentId, AgentToolType agentToolType) { Agent agent = getAgent(agentId); if (agent == null) { return Lists.newArrayList(); @@ -96,16 +96,16 @@ public class AgentServiceImpl implements AgentService { if (CollectionUtils.isEmpty(tools)) { return Lists.newArrayList(); } - return tools.stream().map(tool -> JSONObject.parseObject(tool, CommonAgentTool.class)) + return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class)) .collect(Collectors.toList()); } public Set getModelIds(Integer agentId, AgentToolType agentToolType) { - List commonAgentTools = getParserTools(agentId, agentToolType); + List commonAgentTools = getParserTools(agentId, agentToolType); if (CollectionUtils.isEmpty(commonAgentTools)) { return new HashSet<>(); } - return commonAgentTools.stream().map(CommonAgentTool::getModelIds) + return commonAgentTools.stream().map(NL2SQLTool::getModelIds) .filter(modelIds -> !CollectionUtils.isEmpty(modelIds)) .flatMap(Collection::stream) .collect(Collectors.toSet()); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java index fdcd676a6..98be626bd 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java @@ -5,9 +5,9 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.AgentConfig; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.LLMParserTool; -import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.LLMParserTool; +import com.tencent.supersonic.chat.agent.RuleParserTool; import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq; @@ -411,8 +411,8 @@ public class ConfigureDemo implements ApplicationListener agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长", "超音数访问次数最高的部门")); AgentConfig agentConfig = new AgentConfig(); - RuleQueryTool ruleQueryTool = new RuleQueryTool(); - ruleQueryTool.setType(AgentToolType.RULE); + RuleParserTool ruleQueryTool = new RuleParserTool(); + ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setId("0"); ruleQueryTool.setModelIds(Lists.newArrayList(-1L)); ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name())); @@ -420,7 +420,7 @@ public class ConfigureDemo implements ApplicationListener LLMParserTool llmParserTool = new LLMParserTool(); llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.LLM_S2SQL); + llmParserTool.setType(AgentToolType.NL2SQL_LLM); llmParserTool.setModelIds(Lists.newArrayList(-1L)); agentConfig.getTools().add(llmParserTool); @@ -437,16 +437,16 @@ public class ConfigureDemo implements ApplicationListener agent.setEnableSearch(1); agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人")); AgentConfig agentConfig = new AgentConfig(); - RuleQueryTool ruleQueryTool = new RuleQueryTool(); + RuleParserTool ruleQueryTool = new RuleParserTool(); ruleQueryTool.setId("0"); - ruleQueryTool.setType(AgentToolType.RULE); + ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setModelIds(Lists.newArrayList(-1L)); ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name())); agentConfig.getTools().add(ruleQueryTool); LLMParserTool llmParserTool = new LLMParserTool(); llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.LLM_S2SQL); + llmParserTool.setType(AgentToolType.NL2SQL_LLM); llmParserTool.setModelIds(Lists.newArrayList(-1L)); agentConfig.getTools().add(llmParserTool); @@ -468,7 +468,7 @@ public class ConfigureDemo implements ApplicationListener LLMParserTool llmParserTool = new LLMParserTool(); llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.LLM_S2SQL); + llmParserTool.setType(AgentToolType.NL2SQL_LLM); llmParserTool.setModelIds(Lists.newArrayList(5L, 6L, 7L, 8L)); agentConfig.getTools().add(llmParserTool); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java index 4d5c576de..4361bc9a9 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java @@ -5,10 +5,10 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.AgentConfig; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; -import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool; -import com.tencent.supersonic.chat.agent.tool.PluginTool; -import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; +import com.tencent.supersonic.chat.agent.AgentToolType; +import com.tencent.supersonic.chat.agent.DataAnalyticsTool; +import com.tencent.supersonic.chat.agent.PluginTool; +import com.tencent.supersonic.chat.agent.RuleParserTool; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; @@ -161,9 +161,9 @@ public class DataUtils { return agent; } - private static RuleQueryTool getRuleQueryTool() { - RuleQueryTool ruleQueryTool = new RuleQueryTool(); - ruleQueryTool.setType(AgentToolType.RULE); + private static RuleParserTool getRuleQueryTool() { + RuleParserTool ruleQueryTool = new RuleParserTool(); + ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setModelIds(Lists.newArrayList(1L, 2L)); ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL")); return ruleQueryTool; @@ -176,10 +176,10 @@ public class DataUtils { return pluginTool; } - private static MetricInterpretTool getMetricInterpretTool() { - MetricInterpretTool metricInterpretTool = new MetricInterpretTool(); + private static DataAnalyticsTool getMetricInterpretTool() { + DataAnalyticsTool metricInterpretTool = new DataAnalyticsTool(); metricInterpretTool.setModelId(1L); - metricInterpretTool.setType(AgentToolType.INTERPRET); + metricInterpretTool.setType(AgentToolType.ANALYTICS); metricInterpretTool.setMetricOptions(Lists.newArrayList( new MetricOption(1L), new MetricOption(2L),