From 0b71390fde8e03b8e9e8479bd9937edba5805734 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Mon, 14 Oct 2024 14:21:41 +0800 Subject: [PATCH] [improvement][chat]Introduce AppModule to classify chat apps. --- .../chat/server/executor/PlainTextExecutor.java | 3 ++- .../chat/server/memory/MemoryReviewTask.java | 6 ++++-- .../chat/server/parser/NL2SQLParser.java | 7 +++++-- .../chat/server/rest/ChatModelController.java | 3 ++- .../tencent/supersonic/common/pojo/ChatApp.java | 3 +++ .../supersonic/common/pojo/enums/AppModule.java | 5 +++++ .../supersonic/common/util/ChatAppManager.java | 7 +++++-- .../headless/chat/corrector/LLMSqlCorrector.java | 3 ++- .../chat/parser/llm/OnePassSCSqlGenStrategy.java | 3 ++- .../com/tencent/supersonic/demo/S2ArtistDemo.java | 4 +++- .../com/tencent/supersonic/demo/S2VisitsDemo.java | 5 +++-- .../java/com/tencent/supersonic/chat/BaseTest.java | 6 ++++++ .../supersonic/evaluation/Text2SQLEval.java | 14 ++++++++------ 13 files changed, 50 insertions(+), 19 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/pojo/enums/AppModule.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index d6e7d84ff..74d7f4bbc 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.response.QueryState; @@ -32,7 +33,7 @@ public class PlainTextExecutor implements ChatQueryExecutor { public PlainTextExecutor() { ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("闲聊对话") - .description("直接将原始输入透传大模型").enable(false).build()); + .appModule(AppModule.CHAT).description("直接将原始输入透传大模型").enable(false).build()); } @Override diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index f924d92c8..84564246b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.util.ChatAppManager; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.input.Prompt; @@ -49,8 +50,9 @@ public class MemoryReviewTask { private AgentService agentService; public MemoryReviewTask() { - ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估") - .description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build()); + ChatAppManager.register(APP_KEY, + ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估").appModule(AppModule.CHAT) + .description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build()); } @Scheduled(fixedDelay = 60 * 1000) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 0c83bcbd8..2e409824f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ContextUtils; @@ -79,11 +80,13 @@ public class NL2SQLParser implements ChatQueryParser { public NL2SQLParser() { ChatAppManager.register(APP_KEY_MULTI_TURN, ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写") - .description("通过大模型根据历史对话来改写本轮对话").enable(false).build()); + .appModule(AppModule.CHAT).description("通过大模型根据历史对话来改写本轮对话").enable(false) + .build()); ChatAppManager.register(APP_KEY_ERROR_MESSAGE, ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写") - .description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build()); + .appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语") + .enable(false).build()); } @Override diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java index 8dafefd1c..ccb7751d9 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatModelController.java @@ -12,6 +12,7 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Parameter; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.util.ChatAppManager; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; @@ -52,7 +53,7 @@ public class ChatModelController { @RequestMapping("/getModelAppList") public Map getChatAppList() { - return ChatAppManager.getAllApps(); + return ChatAppManager.getAllApps(AppModule.CHAT); } @RequestMapping("/getModelParameters") diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java index 111805d01..0a9e8e93b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.common.pojo; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.tencent.supersonic.common.pojo.enums.AppModule; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -18,4 +19,6 @@ public class ChatApp { private Integer chatModelId; @JsonIgnore private ChatModelConfig chatModelConfig; + @JsonIgnore + private AppModule appModule; } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AppModule.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AppModule.java new file mode 100644 index 000000000..a65fc23b6 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AppModule.java @@ -0,0 +1,5 @@ +package com.tencent.supersonic.common.pojo.enums; + +public enum AppModule { + CHAT, HEADLESS +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/ChatAppManager.java b/common/src/main/java/com/tencent/supersonic/common/util/ChatAppManager.java index 8a1130a2b..4e94a4d3c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/ChatAppManager.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/ChatAppManager.java @@ -2,8 +2,10 @@ package com.tencent.supersonic.common.util; import com.google.common.collect.Maps; import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; import java.util.Map; +import java.util.stream.Collectors; public class ChatAppManager { private static final Map chatApps = Maps.newConcurrentMap(); @@ -12,7 +14,8 @@ public class ChatAppManager { chatApps.put(key, app); } - public static Map getAllApps() { - return chatApps; + public static Map getAllApps(AppModule appType) { + return chatApps.entrySet().stream().filter(e -> e.getValue().getAppModule().equals(appType)) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java index 5d3248c98..b961f243a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; @@ -40,7 +41,7 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { public LLMSqlCorrector() { ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正") - .description("通过大模型对解析S2SQL做二次修正").enable(false).build()); + .appModule(AppModule.CHAT).description("通过大模型对解析S2SQL做二次修正").enable(false).build()); } @Data diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index b2d20c559..29129322d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; @@ -44,7 +45,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { public OnePassSCSqlGenStrategy() { ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析") - .description("通过大模型做语义解析生成S2SQL").enable(true).build()); + .appModule(AppModule.CHAT).description("通过大模型做语义解析生成S2SQL").enable(true).build()); } @Data diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java index 66fb67468..a09e79ab4 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java @@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.DatasetTool; import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.util.ChatAppManager; @@ -182,7 +183,8 @@ public class S2ArtistDemo extends S2BaseDemo { agent.setToolConfig(JSONObject.toJSONString(toolConfig)); // configure chat apps - Map chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); + Map chatAppConfig = + Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT)); chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); agent.setChatAppConfig(chatAppConfig); agentService.createAgent(agent, defaultUser); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 2b8affed7..ba1b433f3 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -157,10 +157,11 @@ public class S2VisitsDemo extends S2BaseDemo { datasetTool.setType(AgentToolType.DATASET); datasetTool.setDataSetIds(Lists.newArrayList(dataSetId)); toolConfig.getTools().add(datasetTool); - agent.setToolConfig(JSONObject.toJSONString(toolConfig)); + // configure chat apps - Map chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); + Map chatAppConfig = + Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT)); chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); agent.setChatAppConfig(chatAppConfig); Agent agentCreated = agentService.createAgent(agent, defaultUser); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index 63e69e827..7ed95e65c 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat; +import com.google.common.collect.Lists; import com.tencent.supersonic.BaseApplication; import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; @@ -17,6 +18,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import java.time.LocalDate; +import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -38,6 +40,10 @@ public class BaseTest extends BaseApplication { @Value("${s2.demo.enableLLM:false}") protected boolean enableLLM; + protected int agentId; + + + protected List durations = Lists.newArrayList(); protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index d16b82685..bb4f89d17 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.server.agent.*; import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector; import com.tencent.supersonic.util.DataUtils; @@ -16,7 +17,6 @@ import com.tencent.supersonic.util.LLMConfigUtils; import org.junit.jupiter.api.*; import org.springframework.test.context.TestPropertySource; -import java.util.List; import java.util.Map; @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -24,8 +24,8 @@ import java.util.Map; @Disabled public class Text2SQLEval extends BaseTest { - private int agentId; - private List durations = Lists.newArrayList(); + private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3; + private boolean enableLLMCorrection = true; @BeforeAll public void init() { @@ -139,15 +139,17 @@ public class Text2SQLEval extends BaseTest { ToolConfig toolConfig = new ToolConfig(); toolConfig.getTools().add(getDatasetTool()); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); + // create chat model for this evaluation ChatModel chatModel = new ChatModel(); chatModel.setName("Text2SQL LLM"); - chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); + chatModel.setConfig(LLMConfigUtils.getLLMConfig(llmType)); chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser()); Integer chatModelId = chatModel.getId(); // configure chat apps - Map chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); + Map chatAppConfig = + Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT)); chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId)); - chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true); + chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(enableLLMCorrection); agent.setChatAppConfig(chatAppConfig); return agent; }