[improvement][chat]Merge RuleTool and LLMTool.#1766 (#1768)

This commit is contained in:
Jun Zhang
2024-10-10 14:53:54 +08:00
committed by GitHub
parent fe3f4c36b5
commit 85552fb73a
14 changed files with 44 additions and 114 deletions

View File

@@ -54,12 +54,12 @@ public class Agent extends RecordInfo {
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L); return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
} }
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) { public List<DatasetTool> getParserTools(AgentToolType agentToolType) {
List<String> tools = this.getTools(agentToolType); List<String> tools = this.getTools(agentToolType);
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class)) return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@@ -67,17 +67,8 @@ public class Agent extends RecordInfo {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN)); return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
} }
public boolean containsLLMTool() { public boolean containsDatasetTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); return !CollectionUtils.isEmpty(getParserTools(AgentToolType.DATASET));
}
public boolean containsRuleTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
}
public boolean containsNL2SQLTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
} }
public boolean containsAnyTool() { public boolean containsAnyTool() {
@@ -102,11 +93,11 @@ public class Agent extends RecordInfo {
} }
public Set<Long> getDataSetIds(AgentToolType agentToolType) { public Set<Long> getDataSetIds(AgentToolType agentToolType) {
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType); List<DatasetTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) { if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>(); return new HashSet<>();
} }
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds) return commonAgentTools.stream().map(DatasetTool::getDataSetIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream) .filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }

View File

@@ -4,9 +4,9 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
public enum AgentToolType { public enum AgentToolType {
NL2SQL_RULE("基于规则Text-to-SQL"), NL2SQL_LLM("基于大模型Text-to-SQL"), PLUGIN("第三方插件"); DATASET("Text2SQL数据集"), PLUGIN("第三方插件");
private String title; private final String title;
AgentToolType(String title) { AgentToolType(String title) {
this.title = title; this.title = title;
@@ -14,8 +14,7 @@ public enum AgentToolType {
public static Map<AgentToolType, String> getToolTypes() { public static Map<AgentToolType, String> getToolTypes() {
Map<AgentToolType, String> map = new HashMap<>(); Map<AgentToolType, String> map = new HashMap<>();
map.put(NL2SQL_RULE, NL2SQL_RULE.title); map.put(DATASET, DATASET.title);
map.put(NL2SQL_LLM, NL2SQL_LLM.title);
map.put(PLUGIN, PLUGIN.title); map.put(PLUGIN, PLUGIN.title);
return map; return map;
} }

View File

@@ -9,7 +9,8 @@ import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class NL2SQLTool extends AgentTool { public class DatasetTool extends AgentTool {
protected List<Long> dataSetIds; private List<Long> dataSetIds;
private List<String> exampleQuestions;
} }

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
import java.util.List;
@Data
public class LLMParserTool extends NL2SQLTool {
private List<String> exampleQuestions;
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
@Data
public class RuleParserTool extends NL2SQLTool {
private List<String> queryModes;
private List<String> queryTypes;
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
}
}

View File

@@ -85,7 +85,7 @@ public class NL2SQLParser implements ChatQueryParser {
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId()); ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
if (parseContext.enbaleLLM()) { if (!parseContext.isDisableLLM()) {
processMultiTurn(parseContext); processMultiTurn(parseContext);
} }
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
@@ -96,7 +96,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else { } else {
if (parseContext.enbaleLLM()) { if (!parseContext.isDisableLLM()) {
parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(), parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(), text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(),
parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig( parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig(

View File

@@ -21,13 +21,6 @@ public class ParseContext {
if (agent == null) { if (agent == null) {
return false; return false;
} }
return agent.containsNL2SQLTool(); return agent.containsDatasetTool();
}
public boolean enbaleLLM() {
if (agent == null || disableLLM) {
return false;
}
return agent.containsLLMTool();
} }
} }

View File

@@ -105,7 +105,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
} }
private synchronized void doExecuteAgentExamples(Agent agent) { private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMTool() if (!agent.containsDatasetTool()
|| !ModelConfigHelper.testConnection( || !ModelConfigHelper.testConnection(
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL)) ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL))
|| CollectionUtils.isEmpty(agent.getExamples())) { || CollectionUtils.isEmpty(agent.getExamples())) {

View File

@@ -26,22 +26,10 @@ public class QueryReqConverter {
return queryNLReq; return queryNLReq;
} }
ChatModelConfig chatModelConfig =
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
boolean hasLLMTool = agent.containsLLMTool();
boolean hasRuleTool = agent.containsRuleTool();
boolean hasLLMConfig = chatModelConfig != null;
if (parseContext.isDisableLLM()) { if (parseContext.isDisableLLM()) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} else if (hasLLMTool && hasLLMConfig) { } else {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (hasLLMTool && hasRuleTool) {
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
} else if (hasLLMTool) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (hasRuleTool) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} }
queryNLReq.setDataSetIds(agent.getDataSetIds()); queryNLReq.setDataSetIds(agent.getDataSetIds());
@@ -49,6 +37,8 @@ public class QueryReqConverter {
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) { && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
queryNLReq.setMapInfo(queryNLReq.getMapInfo()); queryNLReq.setMapInfo(queryNLReq.getMapInfo());
} }
ChatModelConfig chatModelConfig =
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
queryNLReq.setModelConfig(chatModelConfig); queryNLReq.setModelConfig(chatModelConfig);
queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate()); queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate());
if (chatCtx != null) { if (chatCtx != null) {

View File

@@ -4,7 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
@@ -335,11 +335,11 @@ public class DuSQLDemo extends S2BaseDemo {
agent.setExamples(Lists.newArrayList()); agent.setExamples(Lists.newArrayList());
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
LLMParserTool llmParserTool = new LLMParserTool(); DatasetTool datasetTool = new DatasetTool();
llmParserTool.setId("1"); datasetTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM); datasetTool.setType(AgentToolType.DATASET);
llmParserTool.setDataSetIds(Lists.newArrayList(4L)); datasetTool.setDataSetIds(Lists.newArrayList(4L));
toolConfig.getTools().add(llmParserTool); toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
log.info("agent:{}", JsonUtil.toString(agent)); log.info("agent:{}", JsonUtil.toString(agent));

View File

@@ -4,8 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
@@ -169,19 +168,11 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派")); agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool(); DatasetTool datasetTool = new DatasetTool();
ruleQueryTool.setId("0"); datasetTool.setId("1");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); datasetTool.setType(AgentToolType.DATASET);
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(ruleQueryTool); toolConfig.getTools().add(datasetTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(llmParserTool);
}
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agentService.createAgent(agent, defaultUser); agentService.createAgent(agent, defaultUser);
} }

View File

@@ -8,9 +8,8 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
@@ -151,18 +150,12 @@ public class S2VisitsDemo extends S2BaseDemo {
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长")); "过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
// configure tools // configure tools
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool(); DatasetTool datasetTool = new DatasetTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); datasetTool.setId("1");
ruleQueryTool.setId("0"); datasetTool.setType(AgentToolType.DATASET);
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(ruleQueryTool); toolConfig.getTools().add(datasetTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(llmParserTool);
}
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat models // configure chat models
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap(); Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();

View File

@@ -148,11 +148,11 @@ public class Text2SQLEval extends BaseTest {
return agent; return agent;
} }
private static LLMParserTool getLLMQueryTool() { private static DatasetTool getLLMQueryTool() {
LLMParserTool llmParserTool = new LLMParserTool(); DatasetTool datasetTool = new DatasetTool();
llmParserTool.setType(AgentToolType.NL2SQL_LLM); datasetTool.setType(AgentToolType.DATASET);
llmParserTool.setDataSetIds(Lists.newArrayList(-1L)); datasetTool.setDataSetIds(Lists.newArrayList(-1L));
return llmParserTool; return datasetTool;
} }
} }

View File

@@ -42,6 +42,7 @@ public class DataUtils {
chatParseReq.setQueryText(query); chatParseReq.setQueryText(query);
chatParseReq.setChatId(id); chatParseReq.setChatId(id);
chatParseReq.setUser(user_test); chatParseReq.setUser(user_test);
chatParseReq.setDisableLLM(true);
return chatParseReq; return chatParseReq;
} }