mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-18 08:17:18 +00:00
[improvement][chat]Merge RuleTool and LLMTool.#1766 (#1768)
This commit is contained in:
@@ -4,7 +4,7 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.pojo.JoinCondition;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
@@ -335,11 +335,11 @@ public class DuSQLDemo extends S2BaseDemo {
|
||||
agent.setExamples(Lists.newArrayList());
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
llmParserTool.setId("1");
|
||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
llmParserTool.setDataSetIds(Lists.newArrayList(4L));
|
||||
toolConfig.getTools().add(llmParserTool);
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setId("1");
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(4L));
|
||||
toolConfig.getTools().add(datasetTool);
|
||||
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
log.info("agent:{}", JsonUtil.toString(agent));
|
||||
|
||||
@@ -4,8 +4,7 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
@@ -169,19 +168,11 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setId("0");
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(ruleQueryTool);
|
||||
|
||||
if (demoEnableLlm) {
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
llmParserTool.setId("1");
|
||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(llmParserTool);
|
||||
}
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setId("1");
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(datasetTool);
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
agentService.createAgent(agent, defaultUser);
|
||||
}
|
||||
|
||||
@@ -8,9 +8,8 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
@@ -151,18 +150,12 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
|
||||
// configure tools
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||
ruleQueryTool.setId("0");
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(ruleQueryTool);
|
||||
if (demoEnableLlm) {
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
llmParserTool.setId("1");
|
||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(llmParserTool);
|
||||
}
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setId("1");
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(datasetTool);
|
||||
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
// configure chat models
|
||||
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
|
||||
|
||||
@@ -148,11 +148,11 @@ public class Text2SQLEval extends BaseTest {
|
||||
return agent;
|
||||
}
|
||||
|
||||
private static LLMParserTool getLLMQueryTool() {
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
llmParserTool.setDataSetIds(Lists.newArrayList(-1L));
|
||||
private static DatasetTool getLLMQueryTool() {
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(-1L));
|
||||
|
||||
return llmParserTool;
|
||||
return datasetTool;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ public class DataUtils {
|
||||
chatParseReq.setQueryText(query);
|
||||
chatParseReq.setChatId(id);
|
||||
chatParseReq.setUser(user_test);
|
||||
chatParseReq.setDisableLLM(true);
|
||||
return chatParseReq;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user