[improvement][chat]Merge RuleTool and LLMTool.#1766 (#1768)

This commit is contained in:
Jun Zhang
2024-10-10 14:53:54 +08:00
committed by GitHub
parent fe3f4c36b5
commit 85552fb73a
14 changed files with 44 additions and 114 deletions

View File

@@ -4,7 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
@@ -335,11 +335,11 @@ public class DuSQLDemo extends S2BaseDemo {
agent.setExamples(Lists.newArrayList());
ToolConfig toolConfig = new ToolConfig();
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(4L));
toolConfig.getTools().add(llmParserTool);
DatasetTool datasetTool = new DatasetTool();
datasetTool.setId("1");
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(4L));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
log.info("agent:{}", JsonUtil.toString(agent));

View File

@@ -4,8 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
@@ -169,19 +168,11 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(llmParserTool);
}
DatasetTool datasetTool = new DatasetTool();
datasetTool.setId("1");
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agentService.createAgent(agent, defaultUser);
}

View File

@@ -8,9 +8,8 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
@@ -151,18 +150,12 @@ public class S2VisitsDemo extends S2BaseDemo {
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
// configure tools
ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(llmParserTool);
}
DatasetTool datasetTool = new DatasetTool();
datasetTool.setId("1");
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat models
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();

View File

@@ -148,11 +148,11 @@ public class Text2SQLEval extends BaseTest {
return agent;
}
private static LLMParserTool getLLMQueryTool() {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(-1L));
private static DatasetTool getLLMQueryTool() {
DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(-1L));
return llmParserTool;
return datasetTool;
}
}

View File

@@ -42,6 +42,7 @@ public class DataUtils {
chatParseReq.setQueryText(query);
chatParseReq.setChatId(id);
chatParseReq.setUser(user_test);
chatParseReq.setDisableLLM(true);
return chatParseReq;
}