[improvement][chat]Introduce AppModule to classify chat apps.

This commit is contained in:
jerryjzhang
2024-10-14 14:21:41 +08:00
parent 6a28a49d31
commit 0b71390fde
13 changed files with 50 additions and 19 deletions

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
@@ -32,7 +33,7 @@ public class PlainTextExecutor implements ChatQueryExecutor {
public PlainTextExecutor() { public PlainTextExecutor() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("闲聊对话") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("闲聊对话")
.description("直接将原始输入透传大模型").enable(false).build()); .appModule(AppModule.CHAT).description("直接将原始输入透传大模型").enable(false).build());
} }
@Override @Override

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
@@ -49,8 +50,9 @@ public class MemoryReviewTask {
private AgentService agentService; private AgentService agentService;
public MemoryReviewTask() { public MemoryReviewTask() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估") ChatAppManager.register(APP_KEY,
.description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build()); ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估").appModule(AppModule.CHAT)
.description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build());
} }
@Scheduled(fixedDelay = 60 * 1000) @Scheduled(fixedDelay = 60 * 1000)

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -79,11 +80,13 @@ public class NL2SQLParser implements ChatQueryParser {
public NL2SQLParser() { public NL2SQLParser() {
ChatAppManager.register(APP_KEY_MULTI_TURN, ChatAppManager.register(APP_KEY_MULTI_TURN,
ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写") ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写")
.description("通过大模型根据历史对话来改写本轮对话").enable(false).build()); .appModule(AppModule.CHAT).description("通过大模型根据历史对话来改写本轮对话").enable(false)
.build());
ChatAppManager.register(APP_KEY_ERROR_MESSAGE, ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写") ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
.description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build()); .appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语")
.enable(false).build());
} }
@Override @Override

View File

@@ -12,6 +12,7 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@@ -52,7 +53,7 @@ public class ChatModelController {
@RequestMapping("/getModelAppList") @RequestMapping("/getModelAppList")
public Map<String, ChatApp> getChatAppList() { public Map<String, ChatApp> getChatAppList() {
return ChatAppManager.getAllApps(); return ChatAppManager.getAllApps(AppModule.CHAT);
} }
@RequestMapping("/getModelParameters") @RequestMapping("/getModelParameters")

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.common.pojo; package com.tencent.supersonic.common.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
@@ -18,4 +19,6 @@ public class ChatApp {
private Integer chatModelId; private Integer chatModelId;
@JsonIgnore @JsonIgnore
private ChatModelConfig chatModelConfig; private ChatModelConfig chatModelConfig;
@JsonIgnore
private AppModule appModule;
} }

View File

@@ -0,0 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
/**
 * Module tag attached to a {@code ChatApp} at registration time, used by
 * {@code ChatAppManager.getAllApps(AppModule)} to select a subset of the registered apps.
 */
public enum AppModule {
    CHAT,
    HEADLESS
}

View File

@@ -2,8 +2,10 @@ package com.tencent.supersonic.common.util;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
public class ChatAppManager { public class ChatAppManager {
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap(); private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
@@ -12,7 +14,8 @@ public class ChatAppManager {
chatApps.put(key, app); chatApps.put(key, app);
} }
public static Map<String, ChatApp> getAllApps() { public static Map<String, ChatApp> getAllApps(AppModule appType) {
return chatApps; return chatApps.entrySet().stream().filter(e -> e.getValue().getAppModule().equals(appType))
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()));
} }
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
@@ -40,7 +41,7 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
public LLMSqlCorrector() { public LLMSqlCorrector() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
.description("通过大模型对解析S2SQL做二次修正").enable(false).build()); .appModule(AppModule.CHAT).description("通过大模型对解析S2SQL做二次修正").enable(false).build());
} }
@Data @Data

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
@@ -44,7 +45,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
public OnePassSCSqlGenStrategy() { public OnePassSCSqlGenStrategy() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")
.description("通过大模型做语义解析生成S2SQL").enable(true).build()); .appModule(AppModule.CHAT).description("通过大模型做语义解析生成S2SQL").enable(true).build());
} }
@Data @Data

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool; import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
@@ -182,7 +183,8 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat apps // configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig); agent.setChatAppConfig(chatAppConfig);
agentService.createAgent(agent, defaultUser); agentService.createAgent(agent, defaultUser);

View File

@@ -157,10 +157,11 @@ public class S2VisitsDemo extends S2BaseDemo {
datasetTool.setType(AgentToolType.DATASET); datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId)); datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(datasetTool); toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat apps // configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig); agent.setChatAppConfig(chatAppConfig);
Agent agentCreated = agentService.createAgent(agent, defaultUser); Agent agentCreated = agentService.createAgent(agent, defaultUser);

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat; package com.tencent.supersonic.chat;
import com.google.common.collect.Lists;
import com.tencent.supersonic.BaseApplication; import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
@@ -17,6 +18,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -38,6 +40,10 @@ public class BaseTest extends BaseApplication {
@Value("${s2.demo.enableLLM:false}") @Value("${s2.demo.enableLLM:false}")
protected boolean enableLLM; protected boolean enableLLM;
protected int agentId;
protected List<Long> durations = Lists.newArrayList();
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
throws Exception { throws Exception {

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.*; import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector; import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
@@ -16,7 +17,6 @@ import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.TestPropertySource;
import java.util.List;
import java.util.Map; import java.util.Map;
@TestInstance(TestInstance.Lifecycle.PER_CLASS) @TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -24,8 +24,8 @@ import java.util.Map;
@Disabled @Disabled
public class Text2SQLEval extends BaseTest { public class Text2SQLEval extends BaseTest {
private int agentId; private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
private List<Long> durations = Lists.newArrayList(); private boolean enableLLMCorrection = true;
@BeforeAll @BeforeAll
public void init() { public void init() {
@@ -139,15 +139,17 @@ public class Text2SQLEval extends BaseTest {
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getDatasetTool()); toolConfig.getTools().add(getDatasetTool());
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// create chat model for this evaluation
ChatModel chatModel = new ChatModel(); ChatModel chatModel = new ChatModel();
chatModel.setName("Text2SQL LLM"); chatModel.setName("Text2SQL LLM");
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); chatModel.setConfig(LLMConfigUtils.getLLMConfig(llmType));
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser()); chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
Integer chatModelId = chatModel.getId(); Integer chatModelId = chatModel.getId();
// configure chat apps // configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps()); Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId)); chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId));
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true); chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(enableLLMCorrection);
agent.setChatAppConfig(chatAppConfig); agent.setChatAppConfig(chatAppConfig);
return agent; return agent;
} }