mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(fix)(chat&headless)Agent references ChatModelConfig instead of ModelConfig.
This commit is contained in:
@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.agent;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
@@ -34,7 +34,7 @@ public class Agent extends RecordInfo {
|
|||||||
private Integer status;
|
private Integer status;
|
||||||
private List<String> examples;
|
private List<String> examples;
|
||||||
private String agentConfig;
|
private String agentConfig;
|
||||||
private ModelConfig modelConfig;
|
private ChatModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
private MultiTurnConfig multiTurnConfig;
|
private MultiTurnConfig multiTurnConfig;
|
||||||
private VisualConfig visualConfig;
|
private VisualConfig visualConfig;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
|||||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||||
import org.springframework.web.bind.annotation.PathVariable;
|
import org.springframework.web.bind.annotation.PathVariable;
|
||||||
@@ -51,7 +51,7 @@ public class AgentController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/testLLMConn")
|
@PostMapping("/testLLMConn")
|
||||||
public boolean testLLMConn(@RequestBody ModelConfig modelConfig) {
|
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
|
||||||
return LLMConnHelper.testConnection(modelConfig);
|
return LLMConnHelper.testConnection(modelConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
|||||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
@@ -122,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
|||||||
BeanUtils.copyProperties(agentDO, agent);
|
BeanUtils.copyProperties(agentDO, agent);
|
||||||
agent.setAgentConfig(agentDO.getConfig());
|
agent.setAgentConfig(agentDO.getConfig());
|
||||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||||
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ModelConfig.class));
|
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class));
|
||||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||||
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.server.util;
|
package com.tencent.supersonic.chat.server.util;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.provider.ModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
@@ -9,10 +9,9 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LLMConnHelper {
|
public class LLMConnHelper {
|
||||||
public static boolean testConnection(ModelConfig modelConfig) {
|
public static boolean testConnection(ChatModelConfig modelConfig) {
|
||||||
try {
|
try {
|
||||||
if (modelConfig == null || modelConfig.getChatModel() == null
|
if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||||
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
||||||
|
|||||||
@@ -18,11 +18,18 @@ public class QueryReqConverter {
|
|||||||
if (agent == null) {
|
if (agent == null) {
|
||||||
return queryNLReq;
|
return queryNLReq;
|
||||||
}
|
}
|
||||||
if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
|
|
||||||
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
boolean hasLLMTool = agent.containsLLMParserTool();
|
||||||
} else if (agent.containsLLMParserTool()) {
|
boolean hasRuleTool = agent.containsRuleTool();
|
||||||
|
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
|
||||||
|
|
||||||
|
if (hasLLMTool && hasLLMConfig) {
|
||||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||||
} else if (agent.containsRuleTool()) {
|
} else if (hasLLMTool && hasRuleTool) {
|
||||||
|
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||||
|
} else if (hasLLMTool) {
|
||||||
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||||
|
} else if (hasRuleTool) {
|
||||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||||
}
|
}
|
||||||
queryNLReq.setDataSetIds(agent.getDataSetIds());
|
queryNLReq.setDataSetIds(agent.getDataSetIds());
|
||||||
|
|||||||
@@ -19,19 +19,18 @@ public class ModelProvider {
|
|||||||
factories.put(provider, modelFactory);
|
factories.put(provider, modelFactory);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ChatLanguageModel getChatModel(ModelConfig modelConfig) {
|
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
|
||||||
if (modelConfig == null || modelConfig.getChatModel() == null
|
if (modelConfig == null
|
||||||
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|
|| StringUtils.isBlank(modelConfig.getProvider())
|
||||||
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
|
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
}
|
}
|
||||||
ChatModelConfig chatModel = modelConfig.getChatModel();
|
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
|
||||||
ModelFactory modelFactory = factories.get(chatModel.getProvider().toUpperCase());
|
|
||||||
if (modelFactory != null) {
|
if (modelFactory != null) {
|
||||||
return modelFactory.createChatModel(chatModel);
|
return modelFactory.createChatModel(modelConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
|
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
|
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
@@ -27,7 +27,7 @@ public class QueryNLReq {
|
|||||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||||
private ModelConfig modelConfig;
|
private ChatModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
|
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -51,7 +51,7 @@ public class ChatQueryContext {
|
|||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
private ChatWorkflowState chatWorkflowState;
|
private ChatWorkflowState chatWorkflowState;
|
||||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||||
private ModelConfig modelConfig;
|
private ChatModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
private List<SqlExemplar> dynamicExemplars;
|
private List<SqlExemplar> dynamicExemplars;
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -23,7 +23,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
|
|||||||
@Autowired
|
@Autowired
|
||||||
protected PromptHelper promptHelper;
|
protected PromptHelper promptHelper;
|
||||||
|
|
||||||
protected ChatLanguageModel getChatLanguageModel(ModelConfig modelConfig) {
|
protected ChatLanguageModel getChatLanguageModel(ChatModelConfig modelConfig) {
|
||||||
return ModelProvider.getChatModel(modelConfig);
|
return ModelProvider.getChatModel(modelConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonValue;
|
import com.fasterxml.jackson.annotation.JsonValue;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
@@ -27,7 +27,7 @@ public class LLMReq {
|
|||||||
|
|
||||||
private SqlGenType sqlGenType;
|
private SqlGenType sqlGenType;
|
||||||
|
|
||||||
private ModelConfig modelConfig;
|
private ChatModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
|
|
||||||
private List<SqlExemplar> dynamicExemplars;
|
private List<SqlExemplar> dynamicExemplars;
|
||||||
|
|||||||
@@ -141,7 +141,6 @@ public class S2VisitsDemo extends S2BaseDemo {
|
|||||||
chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
|
chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
|
||||||
chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
|
chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
|
||||||
chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
|
chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
|
||||||
chatService.parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy访问次数");
|
|
||||||
chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
|
chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
|||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
@@ -109,7 +108,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
GLM
|
GLM
|
||||||
}
|
}
|
||||||
|
|
||||||
private static ModelConfig getLLMConfig(LLMType type) {
|
private static ChatModelConfig getLLMConfig(LLMType type) {
|
||||||
String baseUrl;
|
String baseUrl;
|
||||||
String apiKey;
|
String apiKey;
|
||||||
String modelName;
|
String modelName;
|
||||||
@@ -151,9 +150,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
chatModel.setTemperature(temperature);
|
chatModel.setTemperature(temperature);
|
||||||
chatModel.setProvider("open_ai");
|
chatModel.setProvider("open_ai");
|
||||||
|
|
||||||
ModelConfig modelConfig = new ModelConfig();
|
return chatModel;
|
||||||
modelConfig.setChatModel(chatModel);
|
|
||||||
return modelConfig;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user