(fix)(chat&headless)Agent references ChatModelConfig instead of ModelConfig.

This commit is contained in:
jerryjzhang
2024-07-12 11:58:38 +08:00
parent 37da1ac2ae
commit 5bf4a4160d
12 changed files with 37 additions and 36 deletions

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
@@ -34,7 +34,7 @@ public class Agent extends RecordInfo {
private Integer status; private Integer status;
private List<String> examples; private List<String> examples;
private String agentConfig; private String agentConfig;
private ModelConfig modelConfig; private ChatModelConfig modelConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig; private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig; private VisualConfig visualConfig;

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
@@ -51,7 +51,7 @@ public class AgentController {
} }
@PostMapping("/testLLMConn") @PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody ModelConfig modelConfig) { public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
return LLMConnHelper.testConnection(modelConfig); return LLMConnHelper.testConnection(modelConfig);
} }

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
@@ -122,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
BeanUtils.copyProperties(agentDO, agent); BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig()); agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ModelConfig.class)); agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.util; package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider; import dev.langchain4j.provider.ModelProvider;
@@ -9,10 +9,9 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j @Slf4j
public class LLMConnHelper { public class LLMConnHelper {
public static boolean testConnection(ModelConfig modelConfig) { public static boolean testConnection(ChatModelConfig modelConfig) {
try { try {
if (modelConfig == null || modelConfig.getChatModel() == null if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) {
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
return false; return false;
} }
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);

View File

@@ -18,11 +18,18 @@ public class QueryReqConverter {
if (agent == null) { if (agent == null) {
return queryNLReq; return queryNLReq;
} }
if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); boolean hasLLMTool = agent.containsLLMParserTool();
} else if (agent.containsLLMParserTool()) { boolean hasRuleTool = agent.containsRuleTool();
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
if (hasLLMTool && hasLLMConfig) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM); queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (agent.containsRuleTool()) { } else if (hasLLMTool && hasRuleTool) {
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
} else if (hasLLMTool) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (hasRuleTool) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} }
queryNLReq.setDataSetIds(agent.getDataSetIds()); queryNLReq.setDataSetIds(agent.getDataSetIds());

View File

@@ -19,19 +19,18 @@ public class ModelProvider {
factories.put(provider, modelFactory); factories.put(provider, modelFactory);
} }
public static ChatLanguageModel getChatModel(ModelConfig modelConfig) { public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
if (modelConfig == null || modelConfig.getChatModel() == null if (modelConfig == null
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider()) || StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) { || StringUtils.isBlank(modelConfig.getBaseUrl())) {
return ContextUtils.getBean(ChatLanguageModel.class); return ContextUtils.getBean(ChatLanguageModel.class);
} }
ChatModelConfig chatModel = modelConfig.getChatModel(); ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
ModelFactory modelFactory = factories.get(chatModel.getProvider().toUpperCase());
if (modelFactory != null) { if (modelFactory != null) {
return modelFactory.createChatModel(chatModel); return modelFactory.createChatModel(modelConfig);
} }
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider()); throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
} }
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) { public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -27,7 +27,7 @@ public class QueryNLReq {
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private ModelConfig modelConfig; private ChatModelConfig modelConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList(); private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
} }

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -51,7 +51,7 @@ public class ChatQueryContext {
@JsonIgnore @JsonIgnore
private ChatWorkflowState chatWorkflowState; private ChatWorkflowState chatWorkflowState;
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private ModelConfig modelConfig; private ChatModelConfig modelConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars; private List<SqlExemplar> dynamicExemplars;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -23,7 +23,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
@Autowired @Autowired
protected PromptHelper promptHelper; protected PromptHelper promptHelper;
protected ChatLanguageModel getChatLanguageModel(ModelConfig modelConfig) { protected ChatLanguageModel getChatLanguageModel(ChatModelConfig modelConfig) {
return ModelProvider.getChatModel(modelConfig); return ModelProvider.getChatModel(modelConfig);
} }

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -27,7 +27,7 @@ public class LLMReq {
private SqlGenType sqlGenType; private SqlGenType sqlGenType;
private ModelConfig modelConfig; private ChatModelConfig modelConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars; private List<SqlExemplar> dynamicExemplars;

View File

@@ -141,7 +141,6 @@ public class S2VisitsDemo extends S2BaseDemo {
chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计"); chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天"); chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长"); chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
chatService.parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy访问次数");
chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门"); chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
} }

View File

@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
@@ -109,7 +108,7 @@ public class Text2SQLEval extends BaseTest {
GLM GLM
} }
private static ModelConfig getLLMConfig(LLMType type) { private static ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl; String baseUrl;
String apiKey; String apiKey;
String modelName; String modelName;
@@ -151,9 +150,7 @@ public class Text2SQLEval extends BaseTest {
chatModel.setTemperature(temperature); chatModel.setTemperature(temperature);
chatModel.setProvider("open_ai"); chatModel.setProvider("open_ai");
ModelConfig modelConfig = new ModelConfig(); return chatModel;
modelConfig.setChatModel(chatModel);
return modelConfig;
} }
} }