[improvement][chat]Restructure Agent&Tool package

This commit is contained in:
jerryjzhang
2023-11-29 16:34:52 +08:00
parent c11a242f34
commit 57f7d0c67d
22 changed files with 89 additions and 91 deletions

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.Objects; import java.util.Objects;
import lombok.Data; import lombok.Data;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.agent; package com.tencent.supersonic.chat.agent;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentTool;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool; package com.tencent.supersonic.chat.agent;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;

View File

@@ -0,0 +1,8 @@
package com.tencent.supersonic.chat.agent;
public enum AgentToolType {
NL2SQL_RULE,
NL2SQL_LLM,
PLUGIN,
ANALYTICS
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool; package com.tencent.supersonic.chat.agent;
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption; import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
import lombok.Data; import lombok.Data;
@@ -7,7 +7,7 @@ import java.util.List;
@Data @Data
public class MetricInterpretTool extends AgentTool { public class DataAnalyticsTool extends AgentTool {
private Long modelId; private Long modelId;

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.agent;
import lombok.Data;
import java.util.List;
@Data
public class LLMParserTool extends NL2SQLTool {
private List<String> exampleQuestions;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool; package com.tencent.supersonic.chat.agent;
import java.util.List; import java.util.List;
@@ -9,7 +9,7 @@ import lombok.NoArgsConstructor;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class CommonAgentTool extends AgentTool { public class NL2SQLTool extends AgentTool {
protected List<Long> modelIds; protected List<Long> modelIds;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool; package com.tencent.supersonic.chat.agent;
import lombok.Data; import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool; package com.tencent.supersonic.chat.agent;
import lombok.Data; import lombok.Data;
@@ -7,7 +7,7 @@ import org.apache.commons.collections.CollectionUtils;
import java.util.List; import java.util.List;
@Data @Data
public class RuleQueryTool extends CommonAgentTool { public class RuleParserTool extends NL2SQLTool {
private List<String> queryModes; private List<String> queryModes;

View File

@@ -1,8 +0,0 @@
package com.tencent.supersonic.chat.agent.tool;
public enum AgentToolType {
RULE,
LLM_S2SQL,
PLUGIN,
INTERPRET
}

View File

@@ -1,12 +0,0 @@
package com.tencent.supersonic.chat.agent.tool;
import lombok.Data;
import java.util.List;
@Data
public class LLMParserTool extends CommonAgentTool {
private List<String> exampleQuestions;
}

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.parser.llm.interpret;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool; import com.tencent.supersonic.chat.agent.DataAnalyticsTool;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -42,7 +42,7 @@ public class MetricInterpretParser implements SemanticParser {
log.info("skip MetricInterpretParser"); log.info("skip MetricInterpretParser");
return; return;
} }
Map<Long, MetricInterpretTool> metricInterpretToolMap = Map<Long, DataAnalyticsTool> metricInterpretToolMap =
getMetricInterpretTools(queryContext.getRequest().getAgentId()); getMetricInterpretTools(queryContext.getRequest().getAgentId());
log.info("metric interpret tool : {}", metricInterpretToolMap); log.info("metric interpret tool : {}", metricInterpretToolMap);
if (CollectionUtils.isEmpty(metricInterpretToolMap)) { if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
@@ -50,7 +50,7 @@ public class MetricInterpretParser implements SemanticParser {
} }
Map<Long, List<SchemaElementMatch>> elementMatches = queryContext.getMapInfo().getModelElementMatches(); Map<Long, List<SchemaElementMatch>> elementMatches = queryContext.getMapInfo().getModelElementMatches();
for (Long modelId : elementMatches.keySet()) { for (Long modelId : elementMatches.keySet()) {
MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId); DataAnalyticsTool metricInterpretTool = metricInterpretToolMap.get(modelId);
if (metricInterpretTool == null) { if (metricInterpretTool == null) {
continue; continue;
} }
@@ -86,22 +86,22 @@ public class MetricInterpretParser implements SemanticParser {
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
private Map<Long, MetricInterpretTool> getMetricInterpretTools(Integer agentId) { private Map<Long, DataAnalyticsTool> getMetricInterpretTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class); AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId); Agent agent = agentService.getAgent(agentId);
if (agent == null) { if (agent == null) {
return new HashMap<>(); return new HashMap<>();
} }
List<String> tools = agent.getTools(AgentToolType.INTERPRET); List<String> tools = agent.getTools(AgentToolType.ANALYTICS);
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return new HashMap<>(); return new HashMap<>();
} }
List<MetricInterpretTool> metricInterpretTools = tools.stream().map(tool -> List<DataAnalyticsTool> metricInterpretTools = tools.stream().map(tool ->
JSONObject.parseObject(tool, MetricInterpretTool.class)) JSONObject.parseObject(tool, DataAnalyticsTool.class))
.filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions())) .filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions()))
.collect(Collectors.toList()); .collect(Collectors.toList());
Map<Long, MetricInterpretTool> metricInterpretToolMap = new HashMap<>(); Map<Long, DataAnalyticsTool> metricInterpretToolMap = new HashMap<>();
for (MetricInterpretTool metricInterpretTool : metricInterpretTools) { for (DataAnalyticsTool metricInterpretTool : metricInterpretTools) {
metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(), metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(),
metricInterpretTool); metricInterpretTool);
} }

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.parser.llm.s2sql; package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -74,7 +74,7 @@ public class LLMRequestService {
} }
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) { public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL); Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.NL2SQL_LLM);
if (agentService.containsAllModel(distinctModelIds)) { if (agentService.containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>(); distinctModelIds = new HashSet<>();
} }
@@ -84,10 +84,10 @@ public class LLMRequestService {
return ModelCluster.build(modelCluster); return ModelCluster.build(modelCluster);
} }
public CommonAgentTool getParserTool(QueryReq request, Set<Long> modelIdSet) { public NL2SQLTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(), List<NL2SQLTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
AgentToolType.LLM_S2SQL); AgentToolType.NL2SQL_LLM);
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream() Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> { .filter(tool -> {
List<Long> modelIds = tool.getModelIds(); List<Long> modelIds = tool.getModelIds();
if (agentService.containsAllModel(new HashSet<>(modelIds))) { if (agentService.containsAllModel(new HashSet<>(modelIds))) {

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.parser.llm.s2sql; package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
@@ -28,7 +28,7 @@ public class LLMResponseService {
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE); LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.setModel(parseResult.getModelCluster()); parseInfo.setModel(parseResult.getModelCluster());
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool(); NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo() parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
.getMatchedElements(parseInfo.getModelClusterKey())); .getMatchedElements(parseInfo.getModelClusterKey()));

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.parser.llm.s2sql; package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -39,7 +39,7 @@ public class LLMS2SQLParser implements SemanticParser {
return; return;
} }
//3.get agent tool and determine whether to skip this parser. //3.get agent tool and determine whether to skip this parser.
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds()); NL2SQLTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
if (Objects.isNull(commonAgentTool)) { if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class); log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
return; return;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.parser.llm.s2sql; package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
@@ -27,7 +27,7 @@ public class ParseResult {
private QueryReq request; private QueryReq request;
private CommonAgentTool commonAgentTool; private NL2SQLTool commonAgentTool;
private List<ElementValue> linkingValues; private List<ElementValue> linkingValues;
} }

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.parser.rule;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -35,7 +35,7 @@ public class AgentCheckParser implements SemanticParser {
if (agent == null) { if (agent == null) {
return; return;
} }
List<RuleQueryTool> queryTools = getRuleTools(agentId); List<RuleParserTool> queryTools = getRuleTools(agentId);
if (CollectionUtils.isEmpty(queryTools)) { if (CollectionUtils.isEmpty(queryTools)) {
queries.clear(); queries.clear();
return; return;
@@ -43,7 +43,7 @@ public class AgentCheckParser implements SemanticParser {
log.info("queries resolved:{} {}", agent.getName(), log.info("queries resolved:{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
queries.removeIf(query -> { queries.removeIf(query -> {
for (RuleQueryTool tool : queryTools) { for (RuleParserTool tool : queryTools) {
if (CollectionUtils.isNotEmpty(tool.getQueryModes()) if (CollectionUtils.isNotEmpty(tool.getQueryModes())
&& !tool.getQueryModes().contains(query.getQueryMode())) { && !tool.getQueryModes().contains(query.getQueryMode())) {
return true; return true;
@@ -73,17 +73,17 @@ public class AgentCheckParser implements SemanticParser {
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
} }
private static List<RuleQueryTool> getRuleTools(Integer agentId) { private static List<RuleParserTool> getRuleTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class); AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId); Agent agent = agentService.getAgent(agentId);
if (agent == null) { if (agent == null) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
List<String> tools = agent.getTools(AgentToolType.RULE); List<String> tools = agent.getTools(AgentToolType.NL2SQL_RULE);
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class)) return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleParserTool.class))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }

View File

@@ -9,8 +9,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.PluginTool; import com.tencent.supersonic.chat.agent.PluginTool;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp; import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.agent.NL2SQLTool;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@@ -19,7 +19,7 @@ public interface AgentService {
void deleteAgent(Integer id); void deleteAgent(Integer id);
List<CommonAgentTool> getParserTools(Integer agentId, AgentToolType agentToolType); List<NL2SQLTool> getParserTools(Integer agentId, AgentToolType agentToolType);
Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType); Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType);

View File

@@ -4,8 +4,8 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO; import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.persistence.repository.AgentRepository; import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.AgentService;
@@ -87,7 +87,7 @@ public class AgentServiceImpl implements AgentService {
return agentDO; return agentDO;
} }
public List<CommonAgentTool> getParserTools(Integer agentId, AgentToolType agentToolType) { public List<NL2SQLTool> getParserTools(Integer agentId, AgentToolType agentToolType) {
Agent agent = getAgent(agentId); Agent agent = getAgent(agentId);
if (agent == null) { if (agent == null) {
return Lists.newArrayList(); return Lists.newArrayList();
@@ -96,16 +96,16 @@ public class AgentServiceImpl implements AgentService {
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
return tools.stream().map(tool -> JSONObject.parseObject(tool, CommonAgentTool.class)) return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
public Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType) { public Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType) {
List<CommonAgentTool> commonAgentTools = getParserTools(agentId, agentToolType); List<NL2SQLTool> commonAgentTools = getParserTools(agentId, agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) { if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>(); return new HashSet<>();
} }
return commonAgentTools.stream().map(CommonAgentTool::getModelIds) return commonAgentTools.stream().map(NL2SQLTool::getModelIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds)) .filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream) .flatMap(Collection::stream)
.collect(Collectors.toSet()); .collect(Collectors.toSet());

View File

@@ -5,9 +5,9 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig; import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.LLMParserTool; import com.tencent.supersonic.chat.agent.LLMParserTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -411,8 +411,8 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数", agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数",
"对比alice和lucy的停留时长", "超音数访问次数最高的部门")); "对比alice和lucy的停留时长", "超音数访问次数最高的部门"));
AgentConfig agentConfig = new AgentConfig(); AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.RULE); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");
ruleQueryTool.setModelIds(Lists.newArrayList(-1L)); ruleQueryTool.setModelIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name())); ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
@@ -420,7 +420,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
LLMParserTool llmParserTool = new LLMParserTool(); LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1"); llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL); llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(-1L)); llmParserTool.setModelIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool); agentConfig.getTools().add(llmParserTool);
@@ -437,16 +437,16 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人")); agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人"));
AgentConfig agentConfig = new AgentConfig(); AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.RULE); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setModelIds(Lists.newArrayList(-1L)); ruleQueryTool.setModelIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name())); ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool); agentConfig.getTools().add(ruleQueryTool);
LLMParserTool llmParserTool = new LLMParserTool(); LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1"); llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL); llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(-1L)); llmParserTool.setModelIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool); agentConfig.getTools().add(llmParserTool);
@@ -468,7 +468,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
LLMParserTool llmParserTool = new LLMParserTool(); LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1"); llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL); llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(5L, 6L, 7L, 8L)); llmParserTool.setModelIds(Lists.newArrayList(5L, 6L, 7L, 8L));
agentConfig.getTools().add(llmParserTool); agentConfig.getTools().add(llmParserTool);

View File

@@ -5,10 +5,10 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig; import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool; import com.tencent.supersonic.chat.agent.DataAnalyticsTool;
import com.tencent.supersonic.chat.agent.tool.PluginTool; import com.tencent.supersonic.chat.agent.PluginTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -161,9 +161,9 @@ public class DataUtils {
return agent; return agent;
} }
private static RuleQueryTool getRuleQueryTool() { private static RuleParserTool getRuleQueryTool() {
RuleQueryTool ruleQueryTool = new RuleQueryTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.RULE); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setModelIds(Lists.newArrayList(1L, 2L)); ruleQueryTool.setModelIds(Lists.newArrayList(1L, 2L));
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL")); ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL"));
return ruleQueryTool; return ruleQueryTool;
@@ -176,10 +176,10 @@ public class DataUtils {
return pluginTool; return pluginTool;
} }
private static MetricInterpretTool getMetricInterpretTool() { private static DataAnalyticsTool getMetricInterpretTool() {
MetricInterpretTool metricInterpretTool = new MetricInterpretTool(); DataAnalyticsTool metricInterpretTool = new DataAnalyticsTool();
metricInterpretTool.setModelId(1L); metricInterpretTool.setModelId(1L);
metricInterpretTool.setType(AgentToolType.INTERPRET); metricInterpretTool.setType(AgentToolType.ANALYTICS);
metricInterpretTool.setMetricOptions(Lists.newArrayList( metricInterpretTool.setMetricOptions(Lists.newArrayList(
new MetricOption(1L), new MetricOption(1L),
new MetricOption(2L), new MetricOption(2L),