diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 1715b04d1..cdb38d81a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.agent; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.pojo.RecordInfo; @@ -34,7 +34,7 @@ public class Agent extends RecordInfo { private Integer status; private List examples; private String agentConfig; - private ModelConfig modelConfig; + private ChatModelConfig modelConfig; private PromptConfig promptConfig; private MultiTurnConfig multiTurnConfig; private VisualConfig visualConfig; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java index 179bc28f7..5ef9f9867 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.util.LLMConnHelper; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -51,7 +51,7 @@ public class AgentController { } @PostMapping("/testLLMConn") - public boolean testLLMConn(@RequestBody ModelConfig modelConfig) { + public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) { return LLMConnHelper.testConnection(modelConfig); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 737953230..0a6b019b0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.util.LLMConnHelper; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.util.JsonUtil; @@ -122,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl BeanUtils.copyProperties(agentDO, agent); agent.setAgentConfig(agentDO.getConfig()); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); - agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ModelConfig.class)); + agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class)); agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java index 50202ae50..40c8de89f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.chat.server.util; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.provider.ModelProvider; @@ -9,10 +9,9 @@ import org.apache.commons.lang3.StringUtils; @Slf4j public class LLMConnHelper { - public static boolean testConnection(ModelConfig modelConfig) { + public static boolean testConnection(ChatModelConfig modelConfig) { try { - if (modelConfig == null || modelConfig.getChatModel() == null - || StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) { + if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) { return false; } ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 62b674b19..deb3ad0e6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -18,11 +18,18 @@ public class QueryReqConverter { if (agent == null) { return queryNLReq; } - if (agent.containsLLMParserTool() && agent.containsRuleTool()) { - queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); - } else if (agent.containsLLMParserTool()) { + + boolean hasLLMTool = agent.containsLLMParserTool(); + boolean hasRuleTool = agent.containsRuleTool(); + boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig()); + + if (hasLLMTool && hasLLMConfig) { queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM); - } else if (agent.containsRuleTool()) { + } else if (hasLLMTool && hasRuleTool) { + queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); + } else if (hasLLMTool) { + queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM); + } else if (hasRuleTool) { queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); } queryNLReq.setDataSetIds(agent.getDataSetIds()); diff --git a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java index cb0d7332a..de0715506 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java @@ -19,19 +19,18 @@ public class ModelProvider { factories.put(provider, modelFactory); } - public static ChatLanguageModel getChatModel(ModelConfig modelConfig) { - if (modelConfig == null || modelConfig.getChatModel() == null - || StringUtils.isBlank(modelConfig.getChatModel().getProvider()) - || StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) { + public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) { + if (modelConfig == null + || StringUtils.isBlank(modelConfig.getProvider()) + || StringUtils.isBlank(modelConfig.getBaseUrl())) { return ContextUtils.getBean(ChatLanguageModel.class); } - ChatModelConfig chatModel = modelConfig.getChatModel(); - ModelFactory modelFactory = factories.get(chatModel.getProvider().toUpperCase()); + ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase()); if (modelFactory != null) { - return modelFactory.createChatModel(chatModel); + return modelFactory.createChatModel(modelConfig); } - throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider()); + throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider()); } public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index af2824f91..df0e98ce8 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; @@ -27,7 +27,7 @@ public class QueryNLReq { private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private SchemaMapInfo mapInfo = new SchemaMapInfo(); private QueryDataType queryDataType = QueryDataType.ALL; - private ModelConfig modelConfig; + private ChatModelConfig modelConfig; private PromptConfig promptConfig; private List dynamicExemplars = Lists.newArrayList(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index 127a42e3d..f4f911426 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat; import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; @@ -51,7 +51,7 @@ public class ChatQueryContext { @JsonIgnore private ChatWorkflowState chatWorkflowState; private QueryDataType queryDataType = QueryDataType.ALL; - private ModelConfig modelConfig; + private ChatModelConfig modelConfig; private PromptConfig promptConfig; private List dynamicExemplars; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java index 795ec6c73..d761ca955 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.chat.parser.llm; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -23,7 +23,7 @@ public abstract class SqlGenStrategy implements InitializingBean { @Autowired protected PromptHelper promptHelper; - protected ChatLanguageModel getChatLanguageModel(ModelConfig modelConfig) { + protected ChatLanguageModel getChatLanguageModel(ChatModelConfig modelConfig) { return ModelProvider.getChatModel(modelConfig); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 65d155030..d9fe0b1de 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.SchemaElement; @@ -27,7 +27,7 @@ public class LLMReq { private SqlGenType sqlGenType; - private ModelConfig modelConfig; + private ChatModelConfig modelConfig; private PromptConfig promptConfig; private List dynamicExemplars; diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 561b130f5..32b18b2d4 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -141,7 +141,6 @@ public class S2VisitsDemo extends S2BaseDemo { chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计"); chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天"); chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长"); - chatService.parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy访问次数"); chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门"); } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 2d2ff9090..60db2d439 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.util.DataUtils; import org.junit.jupiter.api.BeforeAll; @@ -109,7 +108,7 @@ public class Text2SQLEval extends BaseTest { GLM } - private static ModelConfig getLLMConfig(LLMType type) { + private static ChatModelConfig getLLMConfig(LLMType type) { String baseUrl; String apiKey; String modelName; @@ -151,9 +150,7 @@ public class Text2SQLEval extends BaseTest { chatModel.setTemperature(temperature); chatModel.setProvider("open_ai"); - ModelConfig modelConfig = new ModelConfig(); - modelConfig.setChatModel(chatModel); - return modelConfig; + return chatModel; } }