[improvement][chat]Merge RuleTool and LLMTool.#1766 (#1768)

This commit is contained in:
Jun Zhang
2024-10-10 14:53:54 +08:00
committed by GitHub
parent fe3f4c36b5
commit 85552fb73a
14 changed files with 44 additions and 114 deletions

View File

@@ -54,12 +54,12 @@ public class Agent extends RecordInfo {
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
}
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
public List<DatasetTool> getParserTools(AgentToolType agentToolType) {
List<String> tools = this.getTools(agentToolType);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class))
.collect(Collectors.toList());
}
@@ -67,17 +67,8 @@ public class Agent extends RecordInfo {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
}
public boolean containsLLMTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
}
public boolean containsRuleTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
}
public boolean containsNL2SQLTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
public boolean containsDatasetTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.DATASET));
}
public boolean containsAnyTool() {
@@ -102,11 +93,11 @@ public class Agent extends RecordInfo {
}
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
List<DatasetTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
}
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
return commonAgentTools.stream().map(DatasetTool::getDataSetIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream)
.collect(Collectors.toSet());
}

View File

@@ -4,9 +4,9 @@ import java.util.HashMap;
import java.util.Map;
public enum AgentToolType {
NL2SQL_RULE("基于规则Text-to-SQL"), NL2SQL_LLM("基于大模型Text-to-SQL"), PLUGIN("第三方插件");
DATASET("Text2SQL数据集"), PLUGIN("第三方插件");
private String title;
private final String title;
AgentToolType(String title) {
this.title = title;
@@ -14,8 +14,7 @@ public enum AgentToolType {
public static Map<AgentToolType, String> getToolTypes() {
Map<AgentToolType, String> map = new HashMap<>();
map.put(NL2SQL_RULE, NL2SQL_RULE.title);
map.put(NL2SQL_LLM, NL2SQL_LLM.title);
map.put(DATASET, DATASET.title);
map.put(PLUGIN, PLUGIN.title);
return map;
}

View File

@@ -9,7 +9,8 @@ import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class NL2SQLTool extends AgentTool {
public class DatasetTool extends AgentTool {
protected List<Long> dataSetIds;
private List<Long> dataSetIds;
private List<String> exampleQuestions;
}

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
import java.util.List;
@Data
public class LLMParserTool extends NL2SQLTool {
private List<String> exampleQuestions;
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
@Data
public class RuleParserTool extends NL2SQLTool {
private List<String> queryModes;
private List<String> queryTypes;
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
}
}

View File

@@ -85,7 +85,7 @@ public class NL2SQLParser implements ChatQueryParser {
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
if (parseContext.enbaleLLM()) {
if (!parseContext.isDisableLLM()) {
processMultiTurn(parseContext);
}
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
@@ -96,7 +96,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
if (parseContext.enbaleLLM()) {
if (!parseContext.isDisableLLM()) {
parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(),
parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig(

View File

@@ -21,13 +21,6 @@ public class ParseContext {
if (agent == null) {
return false;
}
return agent.containsNL2SQLTool();
}
public boolean enbaleLLM() {
if (agent == null || disableLLM) {
return false;
}
return agent.containsLLMTool();
return agent.containsDatasetTool();
}
}

View File

@@ -105,7 +105,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
}
private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMTool()
if (!agent.containsDatasetTool()
|| !ModelConfigHelper.testConnection(
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL))
|| CollectionUtils.isEmpty(agent.getExamples())) {

View File

@@ -26,22 +26,10 @@ public class QueryReqConverter {
return queryNLReq;
}
ChatModelConfig chatModelConfig =
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
boolean hasLLMTool = agent.containsLLMTool();
boolean hasRuleTool = agent.containsRuleTool();
boolean hasLLMConfig = chatModelConfig != null;
if (parseContext.isDisableLLM()) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} else if (hasLLMTool && hasLLMConfig) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (hasLLMTool && hasRuleTool) {
} else {
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
} else if (hasLLMTool) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (hasRuleTool) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
}
queryNLReq.setDataSetIds(agent.getDataSetIds());
@@ -49,6 +37,8 @@ public class QueryReqConverter {
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
}
ChatModelConfig chatModelConfig =
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
queryNLReq.setModelConfig(chatModelConfig);
queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate());
if (chatCtx != null) {

View File

@@ -4,7 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
@@ -335,11 +335,11 @@ public class DuSQLDemo extends S2BaseDemo {
agent.setExamples(Lists.newArrayList());
ToolConfig toolConfig = new ToolConfig();
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(4L));
toolConfig.getTools().add(llmParserTool);
DatasetTool datasetTool = new DatasetTool();
datasetTool.setId("1");
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(4L));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
log.info("agent:{}", JsonUtil.toString(agent));

View File

@@ -4,8 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
@@ -169,19 +168,11 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(llmParserTool);
}
DatasetTool datasetTool = new DatasetTool();
datasetTool.setId("1");
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agentService.createAgent(agent, defaultUser);
}

View File

@@ -8,9 +8,8 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
@@ -151,18 +150,12 @@ public class S2VisitsDemo extends S2BaseDemo {
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
// configure tools
ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(llmParserTool);
}
DatasetTool datasetTool = new DatasetTool();
datasetTool.setId("1");
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat models
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();

View File

@@ -148,11 +148,11 @@ public class Text2SQLEval extends BaseTest {
return agent;
}
private static LLMParserTool getLLMQueryTool() {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(-1L));
private static DatasetTool getLLMQueryTool() {
DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(-1L));
return llmParserTool;
return datasetTool;
}
}

View File

@@ -42,6 +42,7 @@ public class DataUtils {
chatParseReq.setQueryText(query);
chatParseReq.setChatId(id);
chatParseReq.setUser(user_test);
chatParseReq.setDisableLLM(true);
return chatParseReq;
}