[feature][chat]Introduce dedicated LLM management.#1739

This commit is contained in:
jerryjzhang
2024-10-09 12:00:24 +08:00
parent 0654a54c8d
commit 7b9ff2e281
22 changed files with 367 additions and 43 deletions

View File

@@ -5,17 +5,12 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
@Data
@@ -29,10 +24,9 @@ public class Agent extends RecordInfo {
/** 0 offline, 1 online */
private Integer status;
private List<String> examples;
private String agentConfig;
private ChatModelConfig modelConfig;
private Map<ChatModelType, Integer> modelConfig = Collections.EMPTY_MAP;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;

View File

@@ -8,6 +8,8 @@ import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import dev.langchain4j.data.message.AiMessage;
@@ -44,8 +46,8 @@ public class PlainTextExecutor implements ChatQueryExecutor {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatAgent.getModelConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.RESPONSE_GENERATE));
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult();

View File

@@ -5,6 +5,8 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
@@ -63,8 +65,8 @@ public class MemoryReviewTask {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatAgent.getModelConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.MEMORY_REVIEW));
if (Objects.nonNull(chatLanguageModel)) {
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);

View File

@@ -1,16 +1,19 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -52,7 +55,7 @@ public class NL2SQLParser implements ChatQueryParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
private static final String REWRITE_MULTI_TURN_INSTRUCTION = ""
+ "#Role: You are a data product manager experienced in data requirements."
+ "#Task: Your will be provided with current and history questions asked by a user,"
+ "along with their mapped schema elements(metric, dimension and value),"
@@ -96,8 +99,8 @@ public class NL2SQLParser implements ChatQueryParser {
if (parseContext.enbaleLLM()) {
parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(),
parseContext.getAgent().getExamples(),
parseContext.getAgent().getModelConfig()));
parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig(
parseContext.getAgent(), ChatModelType.RESPONSE_GENERATE)));
}
}
parseResp.setState(text2SqlParseResp.getState());
@@ -158,8 +161,9 @@ public class NL2SQLParser implements ChatQueryParser {
}
private void processMultiTurn(ParseContext parseContext) {
Agent agent = parseContext.getAgent();
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
MultiTurnConfig agentMultiTurnConfig = agent.getMultiTurnConfig();
Boolean globalMultiTurnConfig =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
@@ -170,8 +174,8 @@ public class NL2SQLParser implements ChatQueryParser {
return;
}
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.MULTI_TURN_REWRITE));
// derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
@@ -198,12 +202,12 @@ public class NL2SQLParser implements ChatQueryParser {
variables.put("history_schema", histMapStr);
variables.put("history_sql", histSQL);
Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
Prompt prompt = PromptTemplate.from(REWRITE_MULTI_TURN_INSTRUCTION).apply(variables);
keyPipelineLog.info("QueryRewrite reqPrompt:{}", prompt.text());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenQuery = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
keyPipelineLog.info("QueryRewrite modelResp:{}", rewrittenQuery);
parseContext.setQueryText(rewrittenQuery);
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq);
@@ -226,12 +230,12 @@ public class NL2SQLParser implements ChatQueryParser {
variables.put("examples", exampleStr);
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
keyPipelineLog.info("ErrorRewrite reqPrompt:{}", prompt.text());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenMsg = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenMsg);
keyPipelineLog.info("ErrorRewrite modelResp:{}", rewrittenMsg);
return rewrittenMsg;
}

View File

@@ -45,7 +45,9 @@ public class AgentDO {
private Integer enableSearch;
private Integer enableMemoryReview;
private String modelConfig;
private String multiTurnConfig;
private String visualConfig;

View File

@@ -0,0 +1,33 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.util.Date;
/**
 * Persistence object for one chat model record, mapped to table {@code s2_chat_model}.
 */
@Data
@TableName("s2_chat_model")
public class ChatModelDO {
/** Primary key, declared with {@code IdType.AUTO}. */
@TableId(type = IdType.AUTO)
private Integer id;
/** Name of the chat model. */
private String name;
/** Free-text description of the chat model. */
private String description;
/**
 * Model configuration stored as a string; ChatModelServiceImpl writes it with
 * {@code JsonUtil.toString(ChatModelConfig)} and reads it back with {@code JsonUtil.toObject}.
 */
private String config;
/** Creation timestamp. */
private Date createdAt;
/** Name of the user who created the record. */
private String createdBy;
/** Last-update timestamp. */
private Date updatedAt;
/** Name of the user who last updated the record. */
private String updatedBy;
/** Administrator of this model; presumably a user name (or a list of names) - TODO confirm format. */
private String admin;
/** Viewer(s) of this model; presumably a user name (or a list of names) - TODO confirm format. */
private String viewer;
}

View File

@@ -0,0 +1,9 @@
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatModelDO;
import org.apache.ibatis.annotations.Mapper;
/**
 * MyBatis mapper for {@link ChatModelDO}. It declares no methods of its own; all
 * operations come from {@link BaseMapper}.
 */
@Mapper
public interface ChatModelMapper extends BaseMapper<ChatModelDO> {
}

View File

@@ -0,0 +1,29 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import lombok.Data;
import java.util.Date;
/**
 * API-level representation of a managed chat model. It carries the same fields as the
 * persistence object, except that the configuration is a typed {@link ChatModelConfig}
 * instead of a serialized string.
 */
@Data
public class ChatModel {
    /** Identifier of the chat model. */
    private Integer id;

    /** Name of the chat model. */
    private String name;

    /** Free-text description of the chat model. */
    private String description;

    /**
     * Connection/configuration settings of the model. The field was previously named
     * {@code Config}; Lombok generates {@code getConfig()}/{@code setConfig(..)} for either
     * spelling, so accessors are unchanged.
     */
    private ChatModelConfig config;

    /** Creation timestamp. */
    private Date createdAt;

    /** Name of the user who created the model. */
    private String createdBy;

    /** Last-update timestamp. */
    private Date updatedAt;

    /** Name of the user who last updated the model. */
    private String updatedBy;

    /** Administrator of this model; presumably a user name - TODO confirm format. */
    private String admin;

    /** Viewer(s) of this model; presumably a user name - TODO confirm format. */
    private String viewer;
}

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
@@ -61,6 +61,6 @@ public class AgentController {
@PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
return LLMConnHelper.testConnection(modelConfig);
return ModelConfigHelper.testConnection(modelConfig);
}
}

View File

@@ -0,0 +1,52 @@
package com.tencent.supersonic.chat.server.rest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
 * REST endpoints for chat model management, exposed under both {@code /api/chat/model}
 * and {@code /openapi/chat/model}. Persistence is delegated to {@link ChatModelService}.
 */
@RestController
@RequestMapping({"/api/chat/model", "/openapi/chat/model"})
public class ChatModelController {

    @Autowired
    private ChatModelService chatModelService;

    /** Creates a chat model on behalf of the user resolved from the request. */
    @PostMapping
    public ChatModel createModel(@RequestBody ChatModel model,
            HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
        return chatModelService.createChatModel(model,
                UserHolder.findUser(httpServletRequest, httpServletResponse));
    }

    /** Updates a chat model on behalf of the user resolved from the request. */
    @PutMapping
    public ChatModel updateModel(@RequestBody ChatModel model,
            HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
        return chatModelService.updateChatModel(model,
                UserHolder.findUser(httpServletRequest, httpServletResponse));
    }

    /** Deletes the chat model with the given id; returns {@code true} once the call completes. */
    @DeleteMapping("/{id}")
    public boolean deleteModel(@PathVariable("id") Integer id) {
        chatModelService.deleteChatModel(id);
        return true;
    }

    /** Lists every chat model known to the service. */
    @RequestMapping("/getModelList")
    public List<ChatModel> getModelList() {
        List<ChatModel> models = chatModelService.getChatModels();
        return models;
    }

    /** Checks whether a model can be reached with the supplied configuration. */
    @PostMapping("/testConnection")
    public boolean testConnection(@RequestBody ChatModelConfig modelConfig) {
        boolean reachable = ModelConfigHelper.testConnection(modelConfig);
        return reachable;
    }
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import java.util.List;
/**
 * Service contract for managing chat models: lookup, creation, update and removal.
 */
public interface ChatModelService {

    // ---- queries ----

    /**
     * Looks up a single chat model.
     *
     * @param id identifier of the chat model
     * @return the chat model; ChatModelServiceImpl returns {@code null} when {@code id} is
     *         {@code null} or no record is found
     */
    ChatModel getChatModel(Integer id);

    /**
     * @return all chat models
     */
    List<ChatModel> getChatModels();

    // ---- mutations ----

    /**
     * Creates a chat model.
     *
     * @param chatModel the model to create
     * @param user the acting user
     * @return the created model
     */
    ChatModel createChatModel(ChatModel chatModel, User user);

    /**
     * Updates a chat model.
     *
     * @param chatModel the model carrying the new values
     * @param user the acting user
     * @return the updated model
     */
    ChatModel updateChatModel(ChatModel chatModel, User user);

    /**
     * Removes a chat model.
     *
     * @param id identifier of the chat model to remove
     */
    void deleteChatModel(Integer id);
}

View File

@@ -10,12 +10,13 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
import lombok.extern.slf4j.Slf4j;
@@ -40,6 +41,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
@Autowired
private ChatQueryService chatQueryService;
@Autowired
private ChatModelService chatModelService;
private ExecutorService executorService = Executors.newFixedThreadPool(1);
@Override
@@ -101,7 +105,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
}
private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
if (!agent.containsLLMTool()
|| !ModelConfigHelper.testConnection(
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL))
|| CollectionUtils.isEmpty(agent.getExamples())) {
return;
}
@@ -136,7 +142,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class));
agent.setModelConfig(
JsonUtil.toMap(agentDO.getModelConfig(), ChatModelType.class, Integer.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));

View File

@@ -0,0 +1,89 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatModelDO;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatModelMapper;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * MyBatis-Plus backed implementation of {@link ChatModelService}. Converts between the API
 * object {@link ChatModel} (typed config) and the persistence object {@link ChatModelDO}
 * (config serialized to a string).
 */
@Slf4j
@Service
public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModelDO>
        implements ChatModelService {

    /**
     * @return every stored chat model, converted to its API representation
     */
    @Override
    public List<ChatModel> getChatModels() {
        return list().stream().map(this::convert).collect(Collectors.toList());
    }

    /**
     * @param id identifier of the chat model
     * @return the chat model, or {@code null} when {@code id} is {@code null} or no record exists
     */
    @Override
    public ChatModel getChatModel(Integer id) {
        if (id == null) {
            return null;
        }
        return convert(getById(id));
    }

    /**
     * Persists a new chat model. Audit fields are stamped with the acting user and a single
     * timestamp; a blank admin defaults to the acting user.
     *
     * @param chatModel the model to create
     * @param user the acting user
     * @return the same {@code chatModel} instance, with its id and effective admin filled in
     */
    @Override
    public ChatModel createChatModel(ChatModel chatModel, User user) {
        ChatModelDO chatModelDO = convert(chatModel);
        // One timestamp for both fields so createdAt and updatedAt are equal on creation.
        Date now = new Date();
        chatModelDO.setCreatedBy(user.getName());
        chatModelDO.setCreatedAt(now);
        chatModelDO.setUpdatedBy(user.getName());
        chatModelDO.setUpdatedAt(now);
        applyDefaultAdmin(chatModel, chatModelDO, user);
        save(chatModelDO);
        // Propagate the generated primary key so callers can reference the new record.
        chatModel.setId(chatModelDO.getId());
        return chatModel;
    }

    /**
     * Updates an existing chat model by id. A blank admin defaults to the acting user, and
     * that default is written to the persisted record as well as to the returned model.
     *
     * @param chatModel the model carrying the new values (its id selects the record)
     * @param user the acting user
     * @return the same {@code chatModel} instance, with its effective admin filled in
     */
    @Override
    public ChatModel updateChatModel(ChatModel chatModel, User user) {
        ChatModelDO chatModelDO = convert(chatModel);
        chatModelDO.setUpdatedBy(user.getName());
        chatModelDO.setUpdatedAt(new Date());
        applyDefaultAdmin(chatModel, chatModelDO, user);
        updateById(chatModelDO);
        return chatModel;
    }

    /**
     * @param id identifier of the chat model to remove
     */
    @Override
    public void deleteChatModel(Integer id) {
        removeById(id);
    }

    /**
     * Defaults a blank admin to the acting user on both the persistence object (so it is
     * stored) and the API object (so the returned model matches what was stored).
     */
    private void applyDefaultAdmin(ChatModel chatModel, ChatModelDO chatModelDO, User user) {
        if (StringUtils.isBlank(chatModel.getAdmin())) {
            chatModelDO.setAdmin(user.getName());
            chatModel.setAdmin(user.getName());
        }
    }

    /** Converts a persistence object to its API form, parsing the stored config string. */
    private ChatModel convert(ChatModelDO chatModelDO) {
        if (chatModelDO == null) {
            return null;
        }
        ChatModel chatModel = new ChatModel();
        BeanUtils.copyProperties(chatModelDO, chatModel);
        chatModel.setConfig(JsonUtil.toObject(chatModelDO.getConfig(), ChatModelConfig.class));
        return chatModel;
    }

    /** Converts an API object to its persistence form, serializing the typed config. */
    private ChatModelDO convert(ChatModel chatModel) {
        if (chatModel == null) {
            return null;
        }
        ChatModelDO chatModelDO = new ChatModelDO();
        BeanUtils.copyProperties(chatModel, chatModelDO);
        chatModelDO.setConfig(JsonUtil.toString(chatModel.getConfig()));
        return chatModelDO;
    }
}

View File

@@ -1,14 +1,18 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class LLMConnHelper {
public class ModelConfigHelper {
public static boolean testConnection(ChatModelConfig modelConfig) {
try {
if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) {
@@ -22,4 +26,14 @@ public class LLMConnHelper {
throw new InvalidArgumentException(e.getMessage());
}
}
/**
 * Resolves the chat model configuration an agent uses for the given task type.
 *
 * <p>The agent maps each {@link ChatModelType} to the id of a managed chat model; the id is
 * looked up through {@link ChatModelService}.
 *
 * @param agent the agent whose model mapping is consulted
 * @param modelType the task type to resolve a model for
 * @return the configuration of the mapped chat model, or {@code null} when the agent has no
 *         mapping for {@code modelType} or the mapped model no longer exists
 */
public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) {
    if (agent == null || agent.getModelConfig() == null) {
        return null;
    }
    Integer chatModelId = agent.getModelConfig().get(modelType);
    if (chatModelId == null) {
        return null;
    }
    ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
    // The referenced model may have been deleted after the agent was configured; treat that
    // like "no model configured" instead of failing with a NullPointerException.
    com.tencent.supersonic.chat.server.pojo.ChatModel chatModel =
            chatModelService.getChatModel(chatModelId);
    return chatModel == null ? null : chatModel.getConfig();
}
}

View File

@@ -3,6 +3,8 @@ package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
@@ -24,9 +26,11 @@ public class QueryReqConverter {
return queryNLReq;
}
ChatModelConfig chatModelConfig =
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
boolean hasLLMTool = agent.containsLLMTool();
boolean hasRuleTool = agent.containsRuleTool();
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
boolean hasLLMConfig = chatModelConfig != null;
if (parseContext.isDisableLLM()) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
@@ -45,7 +49,7 @@ public class QueryReqConverter {
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
}
queryNLReq.setModelConfig(agent.getModelConfig());
queryNLReq.setModelConfig(chatModelConfig);
queryNLReq.setPromptConfig(agent.getPromptConfig());
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());