[improvement][chat]Restructure Agent&Tool package

This commit is contained in:
jerryjzhang
2023-11-29 16:34:52 +08:00
parent c11a242f34
commit 57f7d0c67d
22 changed files with 89 additions and 91 deletions

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.Objects;
import lombok.Data;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.agent;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentTool;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool;
package com.tencent.supersonic.chat.agent;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@@ -0,0 +1,8 @@
package com.tencent.supersonic.chat.agent;
public enum AgentToolType {
NL2SQL_RULE,
NL2SQL_LLM,
PLUGIN,
ANALYTICS
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool;
package com.tencent.supersonic.chat.agent;
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
import lombok.Data;
@@ -7,7 +7,7 @@ import java.util.List;
@Data
public class MetricInterpretTool extends AgentTool {
public class DataAnalyticsTool extends AgentTool {
private Long modelId;

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.agent;
import lombok.Data;
import java.util.List;
@Data
public class LLMParserTool extends NL2SQLTool {
private List<String> exampleQuestions;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool;
package com.tencent.supersonic.chat.agent;
import java.util.List;
@@ -9,7 +9,7 @@ import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class CommonAgentTool extends AgentTool {
public class NL2SQLTool extends AgentTool {
protected List<Long> modelIds;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool;
package com.tencent.supersonic.chat.agent;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent.tool;
package com.tencent.supersonic.chat.agent;
import lombok.Data;
@@ -7,7 +7,7 @@ import org.apache.commons.collections.CollectionUtils;
import java.util.List;
@Data
public class RuleQueryTool extends CommonAgentTool {
public class RuleParserTool extends NL2SQLTool {
private List<String> queryModes;

View File

@@ -1,8 +0,0 @@
package com.tencent.supersonic.chat.agent.tool;
public enum AgentToolType {
RULE,
LLM_S2SQL,
PLUGIN,
INTERPRET
}

View File

@@ -1,12 +0,0 @@
package com.tencent.supersonic.chat.agent.tool;
import lombok.Data;
import java.util.List;
@Data
public class LLMParserTool extends CommonAgentTool {
private List<String> exampleQuestions;
}

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.parser.llm.interpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.DataAnalyticsTool;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -42,7 +42,7 @@ public class MetricInterpretParser implements SemanticParser {
log.info("skip MetricInterpretParser");
return;
}
Map<Long, MetricInterpretTool> metricInterpretToolMap =
Map<Long, DataAnalyticsTool> metricInterpretToolMap =
getMetricInterpretTools(queryContext.getRequest().getAgentId());
log.info("metric interpret tool : {}", metricInterpretToolMap);
if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
@@ -50,7 +50,7 @@ public class MetricInterpretParser implements SemanticParser {
}
Map<Long, List<SchemaElementMatch>> elementMatches = queryContext.getMapInfo().getModelElementMatches();
for (Long modelId : elementMatches.keySet()) {
MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId);
DataAnalyticsTool metricInterpretTool = metricInterpretToolMap.get(modelId);
if (metricInterpretTool == null) {
continue;
}
@@ -86,22 +86,22 @@ public class MetricInterpretParser implements SemanticParser {
.collect(Collectors.toSet());
}
private Map<Long, MetricInterpretTool> getMetricInterpretTools(Integer agentId) {
private Map<Long, DataAnalyticsTool> getMetricInterpretTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return new HashMap<>();
}
List<String> tools = agent.getTools(AgentToolType.INTERPRET);
List<String> tools = agent.getTools(AgentToolType.ANALYTICS);
if (CollectionUtils.isEmpty(tools)) {
return new HashMap<>();
}
List<MetricInterpretTool> metricInterpretTools = tools.stream().map(tool ->
JSONObject.parseObject(tool, MetricInterpretTool.class))
List<DataAnalyticsTool> metricInterpretTools = tools.stream().map(tool ->
JSONObject.parseObject(tool, DataAnalyticsTool.class))
.filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions()))
.collect(Collectors.toList());
Map<Long, MetricInterpretTool> metricInterpretToolMap = new HashMap<>();
for (MetricInterpretTool metricInterpretTool : metricInterpretTools) {
Map<Long, DataAnalyticsTool> metricInterpretToolMap = new HashMap<>();
for (DataAnalyticsTool metricInterpretTool : metricInterpretTools) {
metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(),
metricInterpretTool);
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -74,7 +74,7 @@ public class LLMRequestService {
}
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL);
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.NL2SQL_LLM);
if (agentService.containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>();
}
@@ -84,10 +84,10 @@ public class LLMRequestService {
return ModelCluster.build(modelCluster);
}
public CommonAgentTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
AgentToolType.LLM_S2SQL);
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
public NL2SQLTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
List<NL2SQLTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
AgentToolType.NL2SQL_LLM);
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> {
List<Long> modelIds = tool.getModelIds();
if (agentService.containsAllModel(new HashSet<>(modelIds))) {

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.QueryManager;
@@ -28,7 +28,7 @@ public class LLMResponseService {
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.setModel(parseResult.getModelCluster());
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool();
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
.getMatchedElements(parseInfo.getModelClusterKey()));

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -39,7 +39,7 @@ public class LLMS2SQLParser implements SemanticParser {
return;
}
//3.get agent tool and determine whether to skip this parser.
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
NL2SQLTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
return;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
@@ -27,7 +27,7 @@ public class ParseResult {
private QueryReq request;
private CommonAgentTool commonAgentTool;
private NL2SQLTool commonAgentTool;
private List<ElementValue> linkingValues;
}

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.parser.rule;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -35,7 +35,7 @@ public class AgentCheckParser implements SemanticParser {
if (agent == null) {
return;
}
List<RuleQueryTool> queryTools = getRuleTools(agentId);
List<RuleParserTool> queryTools = getRuleTools(agentId);
if (CollectionUtils.isEmpty(queryTools)) {
queries.clear();
return;
@@ -43,7 +43,7 @@ public class AgentCheckParser implements SemanticParser {
log.info("queries resolved:{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
queries.removeIf(query -> {
for (RuleQueryTool tool : queryTools) {
for (RuleParserTool tool : queryTools) {
if (CollectionUtils.isNotEmpty(tool.getQueryModes())
&& !tool.getQueryModes().contains(query.getQueryMode())) {
return true;
@@ -73,17 +73,17 @@ public class AgentCheckParser implements SemanticParser {
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
}
private static List<RuleQueryTool> getRuleTools(Integer agentId) {
private static List<RuleParserTool> getRuleTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.RULE);
List<String> tools = agent.getTools(AgentToolType.NL2SQL_RULE);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class))
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleParserTool.class))
.collect(Collectors.toList());
}

View File

@@ -9,8 +9,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.PluginTool;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.NL2SQLTool;
import java.util.List;
import java.util.Set;
@@ -19,7 +19,7 @@ public interface AgentService {
void deleteAgent(Integer id);
List<CommonAgentTool> getParserTools(Integer agentId, AgentToolType agentToolType);
List<NL2SQLTool> getParserTools(Integer agentId, AgentToolType agentToolType);
Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType);

View File

@@ -4,8 +4,8 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
import com.tencent.supersonic.chat.service.AgentService;
@@ -87,7 +87,7 @@ public class AgentServiceImpl implements AgentService {
return agentDO;
}
public List<CommonAgentTool> getParserTools(Integer agentId, AgentToolType agentToolType) {
public List<NL2SQLTool> getParserTools(Integer agentId, AgentToolType agentToolType) {
Agent agent = getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
@@ -96,16 +96,16 @@ public class AgentServiceImpl implements AgentService {
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, CommonAgentTool.class))
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
.collect(Collectors.toList());
}
public Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType) {
List<CommonAgentTool> commonAgentTools = getParserTools(agentId, agentToolType);
List<NL2SQLTool> commonAgentTools = getParserTools(agentId, agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
}
return commonAgentTools.stream().map(CommonAgentTool::getModelIds)
return commonAgentTools.stream().map(NL2SQLTool::getModelIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream)
.collect(Collectors.toSet());

View File

@@ -5,9 +5,9 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.LLMParserTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.LLMParserTool;
import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -411,8 +411,8 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数",
"对比alice和lucy的停留时长", "超音数访问次数最高的部门"));
AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setModelIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
@@ -420,7 +420,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL);
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool);
@@ -437,16 +437,16 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人"));
AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.RULE);
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setModelIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool);
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL);
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool);
@@ -468,7 +468,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.LLM_S2SQL);
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(5L, 6L, 7L, 8L));
agentConfig.getTools().add(llmParserTool);

View File

@@ -5,10 +5,10 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.DataAnalyticsTool;
import com.tencent.supersonic.chat.agent.PluginTool;
import com.tencent.supersonic.chat.agent.RuleParserTool;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -161,9 +161,9 @@ public class DataUtils {
return agent;
}
private static RuleQueryTool getRuleQueryTool() {
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
private static RuleParserTool getRuleQueryTool() {
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setModelIds(Lists.newArrayList(1L, 2L));
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL"));
return ruleQueryTool;
@@ -176,10 +176,10 @@ public class DataUtils {
return pluginTool;
}
private static MetricInterpretTool getMetricInterpretTool() {
MetricInterpretTool metricInterpretTool = new MetricInterpretTool();
private static DataAnalyticsTool getMetricInterpretTool() {
DataAnalyticsTool metricInterpretTool = new DataAnalyticsTool();
metricInterpretTool.setModelId(1L);
metricInterpretTool.setType(AgentToolType.INTERPRET);
metricInterpretTool.setType(AgentToolType.ANALYTICS);
metricInterpretTool.setMetricOptions(Lists.newArrayList(
new MetricOption(1L),
new MetricOption(2L),