diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 8605d7db5..9b3711097 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -5,17 +5,12 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; -import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.RecordInfo; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import lombok.Data; import org.springframework.util.CollectionUtils; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; @Data @@ -29,10 +24,9 @@ public class Agent extends RecordInfo { /** 0 offline, 1 online */ private Integer status; - private List examples; private String agentConfig; - private ChatModelConfig modelConfig; + private Map modelConfig = Collections.EMPTY_MAP; private PromptConfig promptConfig; private MultiTurnConfig multiTurnConfig; private VisualConfig visualConfig; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index 2b1f0030b..446bcf450 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -8,6 +8,8 @@ import com.tencent.supersonic.chat.server.parser.ParserConfig; import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatManageService; +import com.tencent.supersonic.chat.server.util.ModelConfigHelper; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import dev.langchain4j.data.message.AiMessage; @@ -44,8 +46,8 @@ public class PlainTextExecutor implements ChatQueryExecutor { AgentService agentService = ContextUtils.getBean(AgentService.class); Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId()); - ChatLanguageModel chatLanguageModel = - ModelProvider.getChatModel(chatAgent.getModelConfig()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( + ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.RESPONSE_GENERATE)); Response response = chatLanguageModel.generate(prompt.toUserMessage()); QueryResult result = new QueryResult(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index 76d237e7a..de225e886 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -5,6 +5,8 @@ import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.MemoryService; +import com.tencent.supersonic.chat.server.util.ModelConfigHelper; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; @@ -63,8 +65,8 @@ public class MemoryReviewTask { Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr); - ChatLanguageModel chatLanguageModel = - ModelProvider.getChatModel(chatAgent.getModelConfig()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( + ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.MEMORY_REVIEW)); if (Objects.nonNull(chatLanguageModel)) { String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index c1ba0a369..57faf19fc 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,16 +1,19 @@ package com.tencent.supersonic.chat.server.parser; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; +import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.service.ChatContextService; import com.tencent.supersonic.chat.server.service.ChatManageService; +import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; @@ -52,7 +55,7 @@ public class NL2SQLParser implements ChatQueryParser { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - private static final String REWRITE_USER_QUESTION_INSTRUCTION = "" + private static final String REWRITE_MULTI_TURN_INSTRUCTION = "" + "#Role: You are a data product manager experienced in data requirements." + "#Task: Your will be provided with current and history questions asked by a user," + "along with their mapped schema elements(metric, dimension and value)," @@ -96,8 +99,8 @@ public class NL2SQLParser implements ChatQueryParser { if (parseContext.enbaleLLM()) { parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(), text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(), - parseContext.getAgent().getExamples(), - parseContext.getAgent().getModelConfig())); + parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig( + parseContext.getAgent(), ChatModelType.RESPONSE_GENERATE))); } } parseResp.setState(text2SqlParseResp.getState()); @@ -158,8 +161,9 @@ public class NL2SQLParser implements ChatQueryParser { } private void processMultiTurn(ParseContext parseContext) { + Agent agent = parseContext.getAgent(); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); - MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig(); + MultiTurnConfig agentMultiTurnConfig = agent.getMultiTurnConfig(); Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); @@ -170,8 +174,8 @@ public class NL2SQLParser implements ChatQueryParser { return; } - ChatLanguageModel chatLanguageModel = - ModelProvider.getChatModel(parseContext.getAgent().getModelConfig()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( + ModelConfigHelper.getChatModelConfig(agent, ChatModelType.MULTI_TURN_REWRITE)); // derive mapping result of current question and parsing result of last question. ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); @@ -198,12 +202,12 @@ public class NL2SQLParser implements ChatQueryParser { variables.put("history_schema", histMapStr); variables.put("history_sql", histSQL); - Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables); - keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text()); + Prompt prompt = PromptTemplate.from(REWRITE_MULTI_TURN_INSTRUCTION).apply(variables); + keyPipelineLog.info("QueryRewrite reqPrompt:{}", prompt.text()); Response response = chatLanguageModel.generate(prompt.toUserMessage()); String rewrittenQuery = response.content().text(); - keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery); + keyPipelineLog.info("QueryRewrite modelResp:{}", rewrittenQuery); parseContext.setQueryText(rewrittenQuery); QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq); @@ -226,12 +230,12 @@ public class NL2SQLParser implements ChatQueryParser { variables.put("examples", exampleStr); Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables); - keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text()); + keyPipelineLog.info("ErrorRewrite reqPrompt:{}", prompt.text()); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig); Response response = chatLanguageModel.generate(prompt.toUserMessage()); String rewrittenMsg = response.content().text(); - keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenMsg); + keyPipelineLog.info("ErrorRewrite modelResp:{}", rewrittenMsg); return rewrittenMsg; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java index e9b03ead1..1b927482a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java @@ -45,7 +45,9 @@ public class AgentDO { private Integer enableSearch; private Integer enableMemoryReview; + private String modelConfig; + private String multiTurnConfig; private String visualConfig; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatModelDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatModelDO.java new file mode 100644 index 000000000..d02daf1a8 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatModelDO.java @@ -0,0 +1,33 @@ +package com.tencent.supersonic.chat.server.persistence.dataobject; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +import java.util.Date; + +@Data +@TableName("s2_chat_model") +public class ChatModelDO { + @TableId(type = IdType.AUTO) + private Integer id; + + private String name; + + private String description; + + private String config; + + private Date createdAt; + + private String createdBy; + + private Date updatedAt; + + private String updatedBy; + + private String admin; + + private String viewer; +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatModelMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatModelMapper.java new file mode 100644 index 000000000..6ef44da89 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatModelMapper.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.chat.server.persistence.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatModelDO; +import org.apache.ibatis.annotations.Mapper; + +@Mapper +public interface ChatModelMapper extends BaseMapper { +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatModel.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatModel.java new file mode 100644 index 000000000..cc5eaf517 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatModel.java @@ -0,0 +1,29 @@ +package com.tencent.supersonic.chat.server.pojo; + +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import lombok.Data; + +import java.util.Date; + +@Data +public class ChatModel { + private Integer id; + + private String name; + + private String description; + + private ChatModelConfig Config; + + private Date createdAt; + + private String createdBy; + + private Date updatedAt; + + private String updatedBy; + + private String admin; + + private String viewer; +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java index c5490d0a8..fc6218086 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java @@ -8,7 +8,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.service.AgentService; -import com.tencent.supersonic.chat.server.util.LLMConnHelper; +import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.common.pojo.ChatModelConfig; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.DeleteMapping; @@ -61,6 +61,6 @@ public class AgentController { @PostMapping("/testLLMConn") public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) { - return LLMConnHelper.testConnection(modelConfig); + return ModelConfigHelper.testConnection(modelConfig); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java new file mode 100644 index 000000000..1f6ab2a8d --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java @@ -0,0 +1,52 @@ +package com.tencent.supersonic.chat.server.rest; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.chat.server.pojo.ChatModel; +import com.tencent.supersonic.chat.server.service.ChatModelService; +import com.tencent.supersonic.chat.server.util.ModelConfigHelper; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; + +import java.util.List; + +@RestController +@RequestMapping({"/api/chat/model", "/openapi/chat/model"}) +public class ChatModelController { + @Autowired + private ChatModelService chatModelService; + + @PostMapping + public ChatModel createModel(@RequestBody ChatModel model, + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + User user = UserHolder.findUser(httpServletRequest, httpServletResponse); + return chatModelService.createChatModel(model, user); + } + + @PutMapping + public ChatModel updateModel(@RequestBody ChatModel model, + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + User user = UserHolder.findUser(httpServletRequest, httpServletResponse); + return chatModelService.updateChatModel(model, user); + } + + @DeleteMapping("/{id}") + public boolean deleteModel(@PathVariable("id") Integer id) { + chatModelService.deleteChatModel(id); + return true; + } + + @RequestMapping("/getModelList") + public List getModelList() { + return chatModelService.getChatModels(); + } + + @PostMapping("/testConnection") + public boolean testConnection(@RequestBody ChatModelConfig modelConfig) { + return ModelConfigHelper.testConnection(modelConfig); + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatModelService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatModelService.java new file mode 100644 index 000000000..ef8046a9a --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatModelService.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.chat.server.service; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.server.pojo.ChatModel; + +import java.util.List; + +public interface ChatModelService { + List getChatModels(); + + ChatModel getChatModel(Integer id); + + ChatModel createChatModel(ChatModel chatModel, User user); + + ChatModel updateChatModel(ChatModel chatModel, User user); + + void deleteChatModel(Integer id); +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 646622409..cfdfefdd0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -10,12 +10,13 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper; import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.MemoryService; -import com.tencent.supersonic.chat.server.util.LLMConnHelper; +import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; -import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy; import lombok.extern.slf4j.Slf4j; @@ -40,6 +41,9 @@ public class AgentServiceImpl extends ServiceImpl implem @Autowired private ChatQueryService chatQueryService; + @Autowired + private ChatModelService chatModelService; + private ExecutorService executorService = Executors.newFixedThreadPool(1); @Override @@ -101,7 +105,9 @@ public class AgentServiceImpl extends ServiceImpl implem } private synchronized void doExecuteAgentExamples(Agent agent) { - if (!agent.containsLLMTool() || !LLMConnHelper.testConnection(agent.getModelConfig()) + if (!agent.containsLLMTool() + || !ModelConfigHelper.testConnection( + ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL)) || CollectionUtils.isEmpty(agent.getExamples())) { return; } @@ -136,7 +142,8 @@ public class AgentServiceImpl extends ServiceImpl implem BeanUtils.copyProperties(agentDO, agent); agent.setAgentConfig(agentDO.getConfig()); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); - agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class)); + agent.setModelConfig( + JsonUtil.toMap(agentDO.getModelConfig(), ChatModelType.class, Integer.class)); agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); agent.setMultiTurnConfig( JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatModelServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatModelServiceImpl.java new file mode 100644 index 000000000..962cef617 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatModelServiceImpl.java @@ -0,0 +1,89 @@ +package com.tencent.supersonic.chat.server.service.impl; + +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatModelDO; +import com.tencent.supersonic.chat.server.persistence.mapper.ChatModelMapper; +import com.tencent.supersonic.chat.server.pojo.ChatModel; +import com.tencent.supersonic.chat.server.service.ChatModelService; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.common.util.StringUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.stereotype.Service; + +import java.util.Date; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +@Slf4j +@Service +public class ChatModelServiceImpl extends ServiceImpl + implements ChatModelService { + @Override + public List getChatModels() { + return list().stream().map(this::convert).collect(Collectors.toList()); + } + + @Override + public ChatModel getChatModel(Integer id) { + if (id == null) { + return null; + } + return convert(getById(id)); + } + + @Override + public ChatModel createChatModel(ChatModel chatModel, User user) { + ChatModelDO chatModelDO = convert(chatModel); + chatModelDO.setCreatedBy(user.getName()); + chatModelDO.setCreatedAt(new Date()); + chatModelDO.setUpdatedBy(user.getName()); + chatModelDO.setUpdatedAt(new Date()); + if (StringUtils.isBlank(chatModel.getAdmin())) { + chatModelDO.setAdmin(user.getName()); + } + save(chatModelDO); + return chatModel; + } + + @Override + public ChatModel updateChatModel(ChatModel chatModel, User user) { + ChatModelDO chatModelDO = convert(chatModel); + chatModelDO.setUpdatedBy(user.getName()); + chatModelDO.setUpdatedAt(new Date()); + if (StringUtils.isBlank(chatModel.getAdmin())) { + chatModel.setAdmin(user.getName()); + } + updateById(chatModelDO); + return chatModel; + } + + @Override + public void deleteChatModel(Integer id) { + removeById(id); + } + + private ChatModel convert(ChatModelDO chatModelDO) { + if (chatModelDO == null) { + return null; + } + ChatModel chatModel = new ChatModel(); + BeanUtils.copyProperties(chatModelDO, chatModel); + chatModel.setConfig(JsonUtil.toObject(chatModelDO.getConfig(), ChatModelConfig.class)); + return chatModel; + } + + private ChatModelDO convert(ChatModel chatModel) { + if (chatModel == null) { + return null; + } + ChatModelDO chatModelDO = new ChatModelDO(); + BeanUtils.copyProperties(chatModel, chatModelDO); + chatModelDO.setConfig(JsonUtil.toString(chatModel.getConfig())); + return chatModelDO; + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java similarity index 55% rename from chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java rename to chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java index e0506079d..f94ac4eb5 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java @@ -1,14 +1,18 @@ package com.tencent.supersonic.chat.server.util; +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; +import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.provider.ModelProvider; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @Slf4j -public class LLMConnHelper { +public class ModelConfigHelper { public static boolean testConnection(ChatModelConfig modelConfig) { try { if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) { @@ -22,4 +26,14 @@ public class LLMConnHelper { throw new InvalidArgumentException(e.getMessage()); } } + + public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) { + ChatModelConfig chatModelConfig = null; + if (agent.getModelConfig().containsKey(modelType)) { + Integer chatModelId = agent.getModelConfig().get(modelType); + ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class); + chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig(); + } + return chatModelConfig; + } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 7f01c0472..90842862c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -3,6 +3,8 @@ package com.tencent.supersonic.chat.server.util; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; @@ -24,9 +26,11 @@ public class QueryReqConverter { return queryNLReq; } + ChatModelConfig chatModelConfig = + ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL); boolean hasLLMTool = agent.containsLLMTool(); boolean hasRuleTool = agent.containsRuleTool(); - boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig()); + boolean hasLLMConfig = chatModelConfig != null; if (parseContext.isDisableLLM()) { queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); @@ -45,7 +49,7 @@ public class QueryReqConverter { && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) { queryNLReq.setMapInfo(queryNLReq.getMapInfo()); } - queryNLReq.setModelConfig(agent.getModelConfig()); + queryNLReq.setModelConfig(chatModelConfig); queryNLReq.setPromptConfig(agent.getPromptConfig()); if (chatCtx != null) { queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ChatModelType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ChatModelType.java new file mode 100644 index 000000000..e817e9b3a --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ChatModelType.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.common.pojo.enums; + +public enum ChatModelType { + TEXT_TO_SQL("Convert text query to SQL statement"), MULTI_TURN_REWRITE( + "Rewrite text query for multi-turn conversation"), MEMORY_REVIEW( + "Review memory in order to add few-shot examples"), RESPONSE_GENERATE( + "Generate readable response to the end user"); + + private String description; + + ChatModelType(String description) { + this.description = description; + } +} diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 1f73a9962..994032fd8 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -155,13 +155,8 @@ public class S2VisitsDemo extends S2BaseDemo { agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计"); agent.setStatus(1); agent.setEnableSearch(1); - agent.setExamples(Lists.newArrayList( - "近15天超音数访问次数汇总", - "按部门统计超音数的访问人数", - "对比alice和lucy的停留时长", - "过去30天访问次数最高的部门top3", - "近1个月总访问次数超过100次的部门有几个", - "过去半个月每个核心用户的总停留时长")); + agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长", + "过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长")); AgentConfig agentConfig = new AgentConfig(); RuleParserTool ruleQueryTool = new RuleParserTool(); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index a22df3112..31fe9d33d 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -103,6 +103,21 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( ) ; COMMENT ON TABLE s2_chat_memory IS 'chat memory table '; +CREATE TABLE IF NOT EXISTS `s2_chat_model` +( + id int AUTO_INCREMENT, + name varchar(100) null, + description varchar(500) null, + `config` varchar(500) NOT NULL , + `created_at` TIMESTAMP NOT NULL , + `created_by` varchar(100) NOT NULL , + `updated_at` TIMESTAMP NOT NULL , + `updated_by` varchar(100) NOT NULL, + `admin` varchar(500) NOT NULL, + `viewer` varchar(500) DEFAULT NULL, + PRIMARY KEY (`id`) +); COMMENT ON TABLE s2_chat_model IS 'chat model table'; + create table IF NOT EXISTS s2_user ( id INT AUTO_INCREMENT, @@ -388,7 +403,6 @@ CREATE TABLE IF NOT EXISTS s2_agent PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_agent IS 'agent information table'; - -------demo for semantic and chat CREATE TABLE IF NOT EXISTS `s2_user_department` ( `user_name` varchar(200) NOT NULL, diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index 9a8be0676..df550617b 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -210,6 +210,20 @@ CREATE TABLE IF NOT EXISTS `s2_chat_statistics` ( KEY `commonIndex` (`question_id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +CREATE TABLE IF NOT EXISTS `s2_chat_model` ( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `name` varchar(255) NOT NULL COMMENT '名称', + `description` varchar(500) DEFAULT NULL COMMENT '描述', + `config` text NOT NULL COMMENT '配置信息', + `created_at` datetime NOT NULL COMMENT '创建时间', + `created_by` varchar(100) NOT NULL COMMENT '创建人', + `updated_at` datetime NOT NULL COMMENT '更新时间', + `updated_by` varchar(100) NOT NULL COMMENT '更新人', + `admin` varchar(500) DEFAULT NULL, + `viewer` varchar(500) DEFAULT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表'; + CREATE TABLE IF NOT EXISTS `s2_database` ( `id` bigint(20) NOT NULL AUTO_INCREMENT, `name` varchar(255) NOT NULL COMMENT '名称', diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index 9740e10de..f6a587291 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.headless.api.pojo.SchemaElement; @@ -33,6 +34,8 @@ public class BaseTest extends BaseApplication { protected ChatQueryService chatQueryService; @Autowired protected AgentService agentService; + @Autowired + protected ChatModelService chatModelService; protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index c7b022a84..736beeb5d 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -2,6 +2,8 @@ package com.tencent.supersonic.evaluation; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.BaseTest; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.server.agent.Agent; @@ -9,12 +11,14 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.RuleParserTool; +import com.tencent.supersonic.chat.server.pojo.ChatModel; +import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.LLMConfigUtils; import org.junit.jupiter.api.*; import java.util.List; -import java.util.stream.Collectors; +import java.util.Map; @TestInstance(TestInstance.Lifecycle.PER_CLASS) @Disabled @@ -133,7 +137,13 @@ public class Text2SQLEval extends BaseTest { AgentConfig agentConfig = new AgentConfig(); agentConfig.getTools().add(getLLMQueryTool()); agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); - agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); + ChatModel chatModel = new ChatModel(); + chatModel.setName("Text2SQL LLM"); + chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); + chatModel = chatModelService.createChatModel(chatModel, User.getFakeUser()); + Map chatModelConfig = Maps.newHashMap(); + chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId()); + agent.setModelConfig(chatModelConfig); MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); multiTurnConfig.setEnableMultiTurn(enableMultiturn); agent.setMultiTurnConfig(multiTurnConfig); diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index a22df3112..bd8638399 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -103,6 +103,21 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( ) ; COMMENT ON TABLE s2_chat_memory IS 'chat memory table '; +CREATE TABLE IF NOT EXISTS `s2_chat_model` +( + id int AUTO_INCREMENT, + name varchar(100) null, + description varchar(500) null, + `config` varchar(500) NOT NULL , + `created_at` TIMESTAMP NOT NULL , + `created_by` varchar(100) NOT NULL , + `updated_at` TIMESTAMP NOT NULL , + `updated_by` varchar(100) NOT NULL, + `admin` varchar(500) NOT NULL, + `viewer` varchar(500) DEFAULT NULL, + PRIMARY KEY (`id`) +); COMMENT ON TABLE s2_chat_model IS 'chat model table'; + create table IF NOT EXISTS s2_user ( id INT AUTO_INCREMENT,