diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 6d73b1d01..942757c94 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.agent; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.enums.ChatModelType; import lombok.Data; @@ -24,6 +25,7 @@ public class Agent extends RecordInfo { private Integer enableMemoryReview; private String toolConfig; private Map chatModelConfig = Collections.EMPTY_MAP; + private Map chatAppConfig = Collections.EMPTY_MAP; private PromptConfig promptConfig; private MultiTurnConfig multiTurnConfig; private VisualConfig visualConfig; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index 446bcf450..a567c83c2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -8,8 +8,8 @@ import com.tencent.supersonic.chat.server.parser.ParserConfig; import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatManageService; -import com.tencent.supersonic.chat.server.util.ModelConfigHelper; -import com.tencent.supersonic.common.pojo.enums.ChatModelType; +import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import dev.langchain4j.data.message.AiMessage; @@ -28,26 +28,35 @@ import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULT public class PlainTextExecutor implements ChatQueryExecutor { + private static final String APP_KEY = "SMALL_TALK"; private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to.\n" + "#Task: Respond quickly and nicely to the user." + "#Rules: 1.ALWAYS use the same language as the input.\n" + "#History Inputs: %s\n" + "#Current Input: %s\n" + "#Your response: "; + public PlainTextExecutor() { + ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("闲聊对话") + .description("直接将原始输入透传大模型").enable(true).build()); + } + @Override public QueryResult execute(ExecuteContext executeContext) { if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) { return null; } - String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext), - executeContext.getQueryText()); - Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); - AgentService agentService = ContextUtils.getBean(AgentService.class); Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId()); + ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY); + if (!chatApp.isEnable()) { + return null; + } - ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( - ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.RESPONSE_GENERATE)); + String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext), + executeContext.getQueryText()); + Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); + ChatLanguageModel chatLanguageModel = + ModelProvider.getChatModel(chatApp.getChatModelConfig()); Response response = chatLanguageModel.generate(prompt.toUserMessage()); QueryResult result = new QueryResult(); @@ -60,25 +69,12 @@ public class PlainTextExecutor implements ChatQueryExecutor { private String getHistoryInputs(ExecuteContext executeContext) { StringBuilder historyInput = new StringBuilder(); + List queryResps = getHistoryQueries(executeContext.getChatId(), 5); + queryResps.stream().forEach(p -> { + historyInput.append(p.getQueryText()); + historyInput.append(";"); - AgentService agentService = ContextUtils.getBean(AgentService.class); - Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId()); - - ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); - MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig(); - Boolean globalMultiTurnConfig = - Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); - Boolean multiTurnConfig = - agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() - : globalMultiTurnConfig; - - if (Boolean.TRUE.equals(multiTurnConfig)) { - List queryResps = getHistoryQueries(executeContext.getChatId(), 5); - queryResps.stream().forEach(p -> { - historyInput.append(p.getQueryText()); - historyInput.append(";"); - }); - } + }); return historyInput.toString(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index 8014a9256..80e1df0c6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -6,7 +6,8 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.util.ModelConfigHelper; -import com.tencent.supersonic.common.pojo.enums.ChatModelType; +import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.util.ChatAppManager; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; @@ -29,11 +30,11 @@ public class MemoryReviewTask { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + public static final String APP_KEY = "MEMORY_REVIEW"; private static final String INSTRUCTION = "" + "\n#Role: You are a senior data engineer experienced in writing SQL." + "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer," - + "please take a review and give your opinion." - + "\n#Rules: " + + "please take a review and give your opinion." + "\n#Rules: " + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." + "\n#Question: %s" + "\n#Schema: %s" + "\n#SideInfo: %s" + "\n#SQL: %s" @@ -47,6 +48,11 @@ public class MemoryReviewTask { @Autowired private AgentService agentService; + public MemoryReviewTask() { + ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("记忆启用评估") + .description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build()); + } + @Scheduled(fixedDelay = 60 * 1000) public void review() { try { @@ -58,16 +64,22 @@ public class MemoryReviewTask { private void processMemory(ChatMemoryDO m) { Agent chatAgent = agentService.getAgent(m.getAgentId()); - if (Objects.isNull(chatAgent) || !chatAgent.enableMemoryReview()) { - log.debug("Agent id {} not found or memory review disabled", m.getAgentId()); + if (Objects.isNull(chatAgent)) { + log.warn("Agent id {} not found or memory review disabled", m.getAgentId()); return; } - String promptStr = createPromptString(m); + + ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY); + if (!chatApp.isEnable()) { + return; + } + + String promptStr = createPromptString(m, chatApp.getPrompt()); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr); - ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( - ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.MEMORY_REVIEW)); + ChatLanguageModel chatLanguageModel = + ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp)); if (Objects.nonNull(chatLanguageModel)) { String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response); @@ -77,8 +89,8 @@ public class MemoryReviewTask { } } - private String createPromptString(ChatMemoryDO m) { - return String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), + private String createPromptString(ChatMemoryDO m, String promptTemplate) { + return String.format(promptTemplate, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), m.getS2sql()); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 91af7493d..31cf46a61 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,8 +1,6 @@ package com.tencent.supersonic.chat.server.parser; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; -import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; @@ -11,10 +9,10 @@ import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; -import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; +import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; @@ -47,7 +45,6 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; -import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; @Slf4j @@ -55,6 +52,7 @@ public class NL2SQLParser implements ChatQueryParser { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + public static final String APP_KEY_MULTI_TURN = "REWRITE_MULTI_TURN"; private static final String REWRITE_MULTI_TURN_INSTRUCTION = "" + "#Role: You are a data product manager experienced in data requirements." + "#Task: Your will be provided with current and history questions asked by a user," @@ -68,6 +66,7 @@ public class NL2SQLParser implements ChatQueryParser { + "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}" + "#Rewritten Question: "; + public static final String APP_KEY_ERROR_MESSAGE = "REWRITE_ERROR_MESSAGE"; private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = "" + "#Role: You are a data business partner who closely interacts with business people.\n" + "#Task: Your will be provided with user input, system output and some examples, " @@ -77,6 +76,16 @@ public class NL2SQLParser implements ChatQueryParser { + "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n" + "#Examples: {{examples}}\n" + "#Response: "; + public NL2SQLParser() { + ChatAppManager.register( + ChatApp.builder().key(APP_KEY_MULTI_TURN).prompt(REWRITE_MULTI_TURN_INSTRUCTION) + .name("多轮对话改写").description("通过大模型根据历史对话来改写本轮对话").enable(false).build()); + + ChatAppManager.register(ChatApp.builder().key(APP_KEY_ERROR_MESSAGE) + .prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写") + .description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build()); + } + @Override public void parse(ParseContext parseContext, ParseResp parseResp) { if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) { @@ -97,10 +106,8 @@ public class NL2SQLParser implements ChatQueryParser { parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); } else { if (!parseContext.isDisableLLM()) { - parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(), - text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(), - parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig( - parseContext.getAgent(), ChatModelType.RESPONSE_GENERATE))); + parseResp.setErrorMsg(rewriteErrorMessage(parseContext, + text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars())); } } parseResp.setState(text2SqlParseResp.getState()); @@ -162,22 +169,11 @@ public class NL2SQLParser implements ChatQueryParser { } private void processMultiTurn(ParseContext parseContext) { - Agent agent = parseContext.getAgent(); - ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); - MultiTurnConfig agentMultiTurnConfig = agent.getMultiTurnConfig(); - Boolean globalMultiTurnConfig = - Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); - - Boolean multiTurnConfig = - agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() - : globalMultiTurnConfig; - if (!Boolean.TRUE.equals(multiTurnConfig)) { + ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_MULTI_TURN); + if (!chatApp.isEnable()) { return; } - ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( - ModelConfigHelper.getChatModelConfig(agent, ChatModelType.MULTI_TURN_REWRITE)); - // derive mapping result of current question and parsing result of last question. ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); @@ -203,9 +199,11 @@ public class NL2SQLParser implements ChatQueryParser { variables.put("history_schema", histMapStr); variables.put("history_sql", histSQL); - Prompt prompt = PromptTemplate.from(REWRITE_MULTI_TURN_INSTRUCTION).apply(variables); + Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables); keyPipelineLog.info("QueryRewrite reqPrompt:{}", prompt.text()); + ChatLanguageModel chatLanguageModel = + ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp)); Response response = chatLanguageModel.generate(prompt.toUserMessage()); String rewrittenQuery = response.content().text(); keyPipelineLog.info("QueryRewrite modelResp:{}", rewrittenQuery); @@ -217,24 +215,30 @@ public class NL2SQLParser implements ChatQueryParser { currentMapResult.getQueryText(), rewrittenQuery); } - private String rewriteErrorMessage(String userQuestion, String errMsg, - List similarExemplars, List agentExamples, - ChatModelConfig modelConfig) { + private String rewriteErrorMessage(ParseContext parseContext, String errMsg, + List similarExemplars) { + + ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE); + if (!chatApp.isEnable()) { + return errMsg; + } + Map variables = new HashMap<>(); - variables.put("user_question", userQuestion); + variables.put("user_question", parseContext.getQueryText()); variables.put("system_message", errMsg); StringBuilder exampleStr = new StringBuilder(); similarExemplars.forEach(e -> exampleStr.append( String.format(" ", e.getQuestion(), e.getDbSchema()))); - agentExamples.forEach(e -> exampleStr.append(String.format(" ", e))); + parseContext.getAgent().getExamples() + .forEach(e -> exampleStr.append(String.format(" ", e))); variables.put("examples", exampleStr); - Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables); + Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables); keyPipelineLog.info("ErrorRewrite reqPrompt:{}", prompt.text()); - ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig); + ChatLanguageModel chatLanguageModel = + ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp)); Response response = chatLanguageModel.generate(prompt.toUserMessage()); - String rewrittenMsg = response.content().text(); keyPipelineLog.info("ErrorRewrite modelResp:{}", rewrittenMsg); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java index 0644b80c9..b2a12705b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java @@ -33,15 +33,9 @@ public class AgentDO { private Integer enableSearch; - private Integer enableMemoryReview; - private String toolConfig; private String chatModelConfig; - private String multiTurnConfig; - private String visualConfig; - - private String promptConfig; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java index a711a7c69..cd0db1912 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java @@ -10,12 +10,15 @@ import com.tencent.supersonic.chat.server.config.ChatModelParameters; import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.util.ModelConfigHelper; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.enums.ChatModelType; +import com.tencent.supersonic.common.util.ChatAppManager; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -51,11 +54,9 @@ public class ChatModelController { return chatModelService.getChatModels(); } - @RequestMapping("/getModelTypeList") - public List getModelTypeList() { - return Arrays.stream(ChatModelType.values()).map(t -> ChatModelTypeResp.builder() - .type(t.toString()).name(t.getName()).description(t.getDescription()).build()) - .collect(Collectors.toList()); + @RequestMapping("/getModelAppList") + public List getModelAppList() { + return new ArrayList(ChatAppManager.getAllApps().values()); } @RequestMapping("/getModelParameters") diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index a183a94e0..f7776007a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -15,8 +15,8 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.MemoryService; -import com.tencent.supersonic.chat.server.util.ModelConfigHelper; -import com.tencent.supersonic.common.pojo.enums.ChatModelType; +import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy; import lombok.extern.slf4j.Slf4j; @@ -26,6 +26,7 @@ import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -53,12 +54,6 @@ public class AgentServiceImpl extends ServiceImpl implem @Override public Agent createAgent(Agent agent, User user) { - if (Objects.isNull(agent.getPromptConfig()) - || Objects.isNull(agent.getPromptConfig().getPromptTemplate())) { - PromptConfig promptConfig = new PromptConfig(); - promptConfig.setPromptTemplate(OnePassSCSqlGenStrategy.INSTRUCTION.trim()); - agent.setPromptConfig(promptConfig); - } agent.createdBy(user.getName()); AgentDO agentDO = convert(agent); save(agentDO); @@ -69,12 +64,6 @@ public class AgentServiceImpl extends ServiceImpl implem @Override public Agent updateAgent(Agent agent, User user) { - if (Objects.isNull(agent.getPromptConfig()) - || Objects.isNull(agent.getPromptConfig().getPromptTemplate())) { - PromptConfig promptConfig = new PromptConfig(); - promptConfig.setPromptTemplate(OnePassSCSqlGenStrategy.INSTRUCTION.trim()); - agent.setPromptConfig(promptConfig); - } agent.updatedBy(user.getName()); updateById(convert(agent)); executeAgentExamplesAsync(agent); @@ -105,10 +94,7 @@ public class AgentServiceImpl extends ServiceImpl implem } private synchronized void doExecuteAgentExamples(Agent agent) { - if (!agent.containsDatasetTool() - || !agent.enableMemoryReview() - || !ModelConfigHelper.testConnection( - ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL)) + if (!agent.containsDatasetTool() || !agent.enableMemoryReview() || CollectionUtils.isEmpty(agent.getExamples())) { return; } @@ -144,11 +130,8 @@ public class AgentServiceImpl extends ServiceImpl implem BeanUtils.copyProperties(agentDO, agent); agent.setToolConfig(agentDO.getToolConfig()); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); - agent.setChatModelConfig( - JsonUtil.toMap(agentDO.getChatModelConfig(), ChatModelType.class, Integer.class)); - agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); - agent.setMultiTurnConfig( - JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); + agent.setChatAppConfig( + JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); return agent; } @@ -158,10 +141,8 @@ public class AgentServiceImpl extends ServiceImpl implem BeanUtils.copyProperties(agent, agentDO); agentDO.setToolConfig(agent.getToolConfig()); agentDO.setExamples(JsonUtil.toString(agent.getExamples())); - agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatModelConfig())); - agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig())); + agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatAppConfig())); agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig())); - agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig())); if (agentDO.getStatus() == null) { agentDO.setStatus(1); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java index 2148cc338..250726e29 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.util; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.service.ChatModelService; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; @@ -27,13 +28,10 @@ public class ModelConfigHelper { } } - public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) { - ChatModelConfig chatModelConfig = null; - if (agent.getChatModelConfig().containsKey(modelType)) { - Integer chatModelId = agent.getChatModelConfig().get(modelType); - ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class); - chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig(); - } + public static ChatModelConfig getChatModelConfig(ChatApp chatApp) { + ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class); + ChatModelConfig chatModelConfig = + chatModelService.getChatModel(chatApp.getChatModelId()).getConfig(); return chatModelConfig; } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 026fc14fa..7a8a0c31f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -3,10 +3,12 @@ package com.tencent.supersonic.chat.server.util; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; +import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.BeanMapper; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import org.apache.commons.collections.MapUtils; @@ -37,10 +39,7 @@ public class QueryReqConverter { && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) { queryNLReq.setMapInfo(queryNLReq.getMapInfo()); } - ChatModelConfig chatModelConfig = - ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL); - queryNLReq.setModelConfig(chatModelConfig); - queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate()); + queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); if (chatCtx != null) { queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java new file mode 100644 index 000000000..b56f7733c --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java @@ -0,0 +1,22 @@ +package com.tencent.supersonic.common.pojo; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class ChatApp { + private String key; + private String name; + private String description; + private String prompt; + private boolean enable; + private Integer chatModelId; + @JsonIgnore + private ChatModelConfig chatModelConfig; +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/ChatAppManager.java b/common/src/main/java/com/tencent/supersonic/common/util/ChatAppManager.java new file mode 100644 index 000000000..ba2919925 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/ChatAppManager.java @@ -0,0 +1,23 @@ +package com.tencent.supersonic.common.util; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.tencent.supersonic.common.pojo.ChatApp; + +import java.util.List; +import java.util.Map; + +public class ChatAppManager { + private static final Map chatApps = Maps.newConcurrentMap(); + + public static void register(ChatApp chatApp) { + if (chatApps.containsKey(chatApp.getKey())) { + throw new RuntimeException("Duplicate chat app key is disallowed."); + } + chatApps.put(chatApp.getKey(), chatApp); + } + + public static Map getAllApps() { + return chatApps; + } +} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index aa6c1c483..43884b04e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; @@ -13,6 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import lombok.Data; import java.util.List; +import java.util.Map; import java.util.Set; @Data @@ -26,8 +28,7 @@ public class QueryNLReq { private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private SchemaMapInfo mapInfo = new SchemaMapInfo(); private QueryDataType queryDataType = QueryDataType.ALL; - private ChatModelConfig modelConfig; - private String customPrompt; + private Map chatAppConfig; private List dynamicExemplars = Lists.newArrayList(); private SemanticParseInfo contextParseInfo; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index e6aae57a9..3bdbe0c8f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat; import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; @@ -47,13 +48,14 @@ public class ChatQueryContext { private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SemanticParseInfo contextParseInfo; private MapModeEnum mapModeEnum = MapModeEnum.STRICT; + private QueryDataType queryDataType = QueryDataType.ALL; @JsonIgnore private SemanticSchema semanticSchema; @JsonIgnore private ChatWorkflowState chatWorkflowState; - private QueryDataType queryDataType = QueryDataType.ALL; - private ChatModelConfig modelConfig; - private String customPrompt; + @JsonIgnore + private Map chatAppConfig; + @JsonIgnore private List dynamicExemplars; public List getCandidateQueries() { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java index 31d1467e3..448a3c75e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java @@ -1,5 +1,7 @@ package com.tencent.supersonic.headless.chat.corrector; +import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -23,11 +25,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - public static final String INSTRUCTION = "" + public static final String APP_KEY = "S2SQL_CORRECTOR"; + private static final String INSTRUCTION = "" + "\n#Role: You are a senior data engineer experienced in writing SQL." + "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer," - + "please take a review and help correct it if necessary." - + "\n#Rules: " + + "please take a review and help correct it if necessary." + "\n#Rules: " + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`." + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." + "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." @@ -36,6 +38,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { + "\n6.ALWAYS translate alias created by `AS` command to the same language as the `#Question`." + "\n#Question:{{question}} #InputSQL:{{sql}} #Response:"; + public LLMSqlCorrector() { + ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL修正") + .description("").enable(false).build()); + } + @Data @ToString static class SemanticSql { @@ -52,14 +59,16 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - if (!chatQueryContext.getText2SQLType().enableLLM()) { + ChatApp chatApp = chatQueryContext.getChatAppConfig().get(APP_KEY); + if (!chatQueryContext.getText2SQLType().enableLLM() || !chatApp.isEnable()) { return; } - ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatQueryContext.getModelConfig()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig()); SemanticSqlExtractor extractor = AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); - Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo); + Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo, + chatApp.getPrompt()); keyPipelineLog.info("LLMSqlCorrector reqPrompt:\n{}", prompt.text()); SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); keyPipelineLog.info("LLMSqlCorrector modelResp:\n{}", s2Sql); @@ -68,12 +77,12 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { } } - private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo) { + private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo, + String promptTemplate) { Map variable = new HashMap<>(); variable.put("question", queryText); variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); - String promptTemplate = INSTRUCTION; return PromptTemplate.from(promptTemplate).apply(variable); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 0801aedd2..232121014 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -74,8 +74,7 @@ public class LLMRequestService { llmReq.setTerms(getMappedTerms(queryCtx, dataSetId)); llmReq.setSqlGenType( LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); - llmReq.setModelConfig(queryCtx.getModelConfig()); - llmReq.setCustomPrompt(queryCtx.getCustomPrompt()); + llmReq.setChatAppConfig(queryCtx.getChatAppConfig()); llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars()); return llmReq; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java index 556ca59eb..375292419 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.chat.parser.llm; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.SemanticParser; @@ -69,10 +70,12 @@ public class LLMSqlParser implements SemanticParser { } catch (Exception e) { log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e); } - Double temperature = llmReq.getModelConfig().getTemperature(); + ChatModelConfig chatModelConfig = llmReq.getChatAppConfig() + .get(OnePassSCSqlGenStrategy.APP_KEY).getChatModelConfig(); + Double temperature = chatModelConfig.getTemperature(); if (temperature == 0) { // 报错时增加随机性,减少无效重试 - llmReq.getModelConfig().setTemperature(0.5); + chatModelConfig.setTemperature(0.5); } currentRetry++; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index a5c8d82ae..2e04a1a28 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -1,7 +1,9 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; +import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -11,7 +13,6 @@ import dev.langchain4j.model.output.structured.Description; import dev.langchain4j.service.AiServices; import lombok.Data; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.stereotype.Service; @@ -24,6 +25,7 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + public static final String APP_KEY = "S2SQL_PARSER"; public static final String INSTRUCTION = "" + "\n#Role: You are a data analyst experienced in SQL languages." + "\n#Task: You will be provided with a natural language question asked by users," @@ -40,6 +42,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "\n#Exemplars: {{exemplar}}" + "\n#Question: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + public OnePassSCSqlGenStrategy() { + ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL解析") + .description("通过大模型做语义解析生成S2SQL").enable(true).build()); + } + @Data static class SemanticSql { @Description("thought or remarks to tell users about the sql, make it short.") @@ -62,15 +69,17 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { List> exemplarsList = promptHelper.getFewShotExemplars(llmReq); // 2.generate sql generation prompt for each self-consistency inference + ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY); + ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatApp.getChatModelConfig()); + SemanticSqlExtractor extractor = + AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); + Map> prompt2Exemplar = new HashMap<>(); for (List exemplars : exemplarsList) { llmReq.setDynamicExemplars(exemplars); - Prompt prompt = generatePrompt(llmReq, llmResp); + Prompt prompt = generatePrompt(llmReq, llmResp, chatApp); prompt2Exemplar.put(prompt, exemplars); } - ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig()); - SemanticSqlExtractor extractor = - AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); // 3.perform multiple self-consistency inferences parallelly Map output2Prompt = new ConcurrentHashMap<>(); @@ -92,7 +101,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { return llmResp; } - private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) { + private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp, ChatApp chatApp) { StringBuilder exemplars = new StringBuilder(); for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) { String exemplarStr = String.format("\nQuestion:%s,Schema:%s,SideInfo:%s,SQL:%s", @@ -112,10 +121,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { variable.put("information", sideInformation); // use custom prompt template if provided. - String promptTemplate = INSTRUCTION; - if (StringUtils.isNotBlank(llmReq.getCustomPrompt())) { - promptTemplate = llmReq.getCustomPrompt(); - } + String promptTemplate = chatApp.getPrompt(); return PromptTemplate.from(promptTemplate).apply(variable); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 2eba691f0..82622d932 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.api.pojo.SchemaElement; @@ -10,6 +11,7 @@ import org.apache.commons.collections4.CollectionUtils; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; @@ -21,7 +23,7 @@ public class LLMReq { private String currentDate; private String priorExts; private SqlGenType sqlGenType; - private ChatModelConfig modelConfig; + private Map chatAppConfig; private String customPrompt; private List dynamicExemplars; diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index b7d95f0e3..dababd017 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -16,9 +16,11 @@ import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.build.WebBase; import com.tencent.supersonic.chat.server.plugin.build.webpage.WebPageQuery; import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuery; +import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.enums.*; +import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; @@ -148,6 +150,7 @@ public class S2VisitsDemo extends S2BaseDemo { agent.setEnableSearch(1); agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长", "过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长")); + // configure tools ToolConfig toolConfig = new ToolConfig(); DatasetTool datasetTool = new DatasetTool(); @@ -157,16 +160,10 @@ public class S2VisitsDemo extends S2BaseDemo { toolConfig.getTools().add(datasetTool); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); - // configure chat models - Map chatModelConfig = Maps.newHashMap(); - chatModelConfig.put(ChatModelType.TEXT_TO_SQL, demoChatModel.getId()); - chatModelConfig.put(ChatModelType.MEMORY_REVIEW, demoChatModel.getId()); - chatModelConfig.put(ChatModelType.RESPONSE_GENERATE, demoChatModel.getId()); - chatModelConfig.put(ChatModelType.MULTI_TURN_REWRITE, demoChatModel.getId()); - agent.setChatModelConfig(chatModelConfig); - - MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true); - agent.setMultiTurnConfig(multiTurnConfig); + // configure chat apps + Map chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); + chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); + agent.setChatAppConfig(chatAppConfig); Agent agentCreated = agentService.createAgent(agent, defaultUser); return agentCreated.getId(); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java index bac0927c0..13b01bfca 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java @@ -26,10 +26,6 @@ public class SmallTalkDemo extends S2BaseDemo { ToolConfig toolConfig = new ToolConfig(); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平")); - MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); - multiTurnConfig.setEnableMultiTurn(true); - agent.setMultiTurnConfig(multiTurnConfig); - agentService.createAgent(agent, defaultUser); } diff --git a/launchers/standalone/src/main/resources/config.update/sql-update.sql b/launchers/standalone/src/main/resources/config.update/sql-update.sql index 91d7d6817..61091bf0a 100644 --- a/launchers/standalone/src/main/resources/config.update/sql-update.sql +++ b/launchers/standalone/src/main/resources/config.update/sql-update.sql @@ -387,3 +387,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_model` ( ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表'; ALTER TABLE s2_agent RENAME COLUMN config TO tool_config; ALTER TABLE s2_agent RENAME COLUMN model_config TO chat_model_config; + +--20241011 +ALTER TABLE s2_agent DROP COLUMN prompt_config; +ALTER TABLE s2_agent DROP COLUMN multi_turn_config; +ALTER TABLE s2_agent DROP COLUMN enable_memory_review; diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index f5b64c8cd..b1d500d5a 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -391,15 +391,12 @@ CREATE TABLE IF NOT EXISTS s2_agent tool_config varchar(2000) null, llm_config varchar(2000) null, chat_model_config varchar(6000) null, - prompt_config varchar(5000) null, - multi_turn_config varchar(2000) null, visual_config varchar(2000) null, created_by varchar(100) null, created_at TIMESTAMP null, updated_by varchar(100) null, updated_at TIMESTAMP null, enable_search int null, - enable_memory_review int null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_agent IS 'agent information table'; diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index 85f0834a4..f13d3190d 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -73,11 +73,8 @@ CREATE TABLE IF NOT EXISTS `s2_agent` ( `tool_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL, `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `chat_model_config` text COLLATE utf8_unicode_ci DEFAULT NULL, - `prompt_config` text COLLATE utf8_unicode_ci DEFAULT NULL, - `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `enable_search` tinyint DEFAULT 1, - `enable_memory_review` tinyint DEFAULT 0, `created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `created_at` datetime DEFAULT NULL, `updated_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 2fe2e2c58..fefb21314 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -8,7 +8,9 @@ import com.tencent.supersonic.chat.BaseTest; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.server.agent.*; import com.tencent.supersonic.chat.server.pojo.ChatModel; -import com.tencent.supersonic.common.pojo.enums.ChatModelType; +import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.util.ChatAppManager; +import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector; import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.LLMConfigUtils; import org.junit.jupiter.api.*; @@ -135,22 +137,22 @@ public class Text2SQLEval extends BaseTest { Agent agent = new Agent(); agent.setName("Agent for Test"); ToolConfig toolConfig = new ToolConfig(); - toolConfig.getTools().add(getLLMQueryTool()); + toolConfig.getTools().add(getDatasetTool()); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); ChatModel chatModel = new ChatModel(); chatModel.setName("Text2SQL LLM"); chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser()); - Map chatModelConfig = Maps.newHashMap(); - chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId()); - agent.setChatModelConfig(chatModelConfig); - MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); - multiTurnConfig.setEnableMultiTurn(enableMultiturn); - agent.setMultiTurnConfig(multiTurnConfig); + Integer chatModelId = chatModel.getId(); + // configure chat apps + Map chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); + chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId)); + chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true); + agent.setChatAppConfig(chatAppConfig); return agent; } - private static DatasetTool getLLMQueryTool() { + private static DatasetTool getDatasetTool() { DatasetTool datasetTool = new DatasetTool(); datasetTool.setType(AgentToolType.DATASET); datasetTool.setDataSetIds(Lists.newArrayList(-1L)); diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index a21b0d023..92728f4ff 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -391,15 +391,12 @@ CREATE TABLE IF NOT EXISTS s2_agent tool_config varchar(2000) null, llm_config varchar(2000) null, chat_model_config varchar(6000) null, - prompt_config varchar(5000) null, - multi_turn_config varchar(2000) null, visual_config varchar(2000) null, created_by varchar(100) null, created_at TIMESTAMP null, updated_by varchar(100) null, updated_at TIMESTAMP null, enable_search int null, - enable_memory_review int null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_agent IS 'agent information table';