[improvement][chat]Restructure Agent&Tool package

This commit is contained in:
jerryjzhang
2023-11-29 16:34:52 +08:00
parent c11a242f34
commit 57f7d0c67d
22 changed files with 89 additions and 91 deletions

View File

@@ -5,9 +5,9 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.LLMParserTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.LLMParserTool;
import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -411,8 +411,8 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数",
"对比alice和lucy的停留时长", "超音数访问次数最高的部门"));
AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setModelIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
@@ -420,7 +420,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL);
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool);
@@ -437,16 +437,16 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人"));
AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.RULE);
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setModelIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool);
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL);
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool);
@@ -468,7 +468,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL);
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(5L, 6L, 7L, 8L));
agentConfig.getTools().add(llmParserTool);

View File

@@ -5,10 +5,10 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.DataAnalyticsTool;
import com.tencent.supersonic.chat.agent.PluginTool;
import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -161,9 +161,9 @@ public class DataUtils {
return agent;
}
private static RuleQueryTool getRuleQueryTool() {
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
private static RuleParserTool getRuleQueryTool() {
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setModelIds(Lists.newArrayList(1L, 2L));
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL"));
return ruleQueryTool;
@@ -176,10 +176,10 @@ public class DataUtils {
return pluginTool;
}
private static MetricInterpretTool getMetricInterpretTool() {
MetricInterpretTool metricInterpretTool = new MetricInterpretTool();
private static DataAnalyticsTool getMetricInterpretTool() {
DataAnalyticsTool metricInterpretTool = new DataAnalyticsTool();
metricInterpretTool.setModelId(1L);
metricInterpretTool.setType(AgentToolType.INTERPRET);
metricInterpretTool.setType(AgentToolType.ANALYTICS);
metricInterpretTool.setMetricOptions(Lists.newArrayList(
new MetricOption(1L),
new MetricOption(2L),