[improvement][chat]Introduce AppModule to classify chat apps.

This commit is contained in:
jerryjzhang
2024-10-14 14:21:41 +08:00
parent 6a28a49d31
commit 0b71390fde
13 changed files with 50 additions and 19 deletions

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
@@ -32,7 +33,7 @@ public class PlainTextExecutor implements ChatQueryExecutor {
public PlainTextExecutor() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("闲聊对话")
.description("直接将原始输入透传大模型").enable(false).build());
.appModule(AppModule.CHAT).description("直接将原始输入透传大模型").enable(false).build());
}
@Override

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
@@ -49,7 +50,8 @@ public class MemoryReviewTask {
private AgentService agentService;
public MemoryReviewTask() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估")
ChatAppManager.register(APP_KEY,
ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估").appModule(AppModule.CHAT)
.description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build());
}

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -79,11 +80,13 @@ public class NL2SQLParser implements ChatQueryParser {
public NL2SQLParser() {
ChatAppManager.register(APP_KEY_MULTI_TURN,
ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写")
.description("通过大模型根据历史对话来改写本轮对话").enable(false).build());
.appModule(AppModule.CHAT).description("通过大模型根据历史对话来改写本轮对话").enable(false)
.build());
ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
.description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build());
.appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语")
.enable(false).build());
}
@Override

View File

@@ -12,6 +12,7 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@@ -52,7 +53,7 @@ public class ChatModelController {
@RequestMapping("/getModelAppList")
public Map<String, ChatApp> getChatAppList() {
return ChatAppManager.getAllApps();
return ChatAppManager.getAllApps(AppModule.CHAT);
}
@RequestMapping("/getModelParameters")

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.common.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -18,4 +19,6 @@ public class ChatApp {
private Integer chatModelId;
@JsonIgnore
private ChatModelConfig chatModelConfig;
@JsonIgnore
private AppModule appModule;
}

View File

@@ -0,0 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
/**
 * Module that a {@code ChatApp} is classified under.
 *
 * <p>The value is attached to an app when it is registered and is used to filter the
 * registry by module.
 */
public enum AppModule {
    /** Apps classified under the chat module. */
    CHAT,

    /** Apps classified under the headless module. */
    HEADLESS;
}

View File

@@ -2,8 +2,10 @@ package com.tencent.supersonic.common.util;
import com.google.common.collect.Maps;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import java.util.Map;
import java.util.stream.Collectors;
public class ChatAppManager {
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
@@ -12,7 +14,8 @@ public class ChatAppManager {
chatApps.put(key, app);
}
public static Map<String, ChatApp> getAllApps() {
return chatApps;
/**
 * Returns the registered chat apps whose module matches {@code appType}.
 *
 * <p>The result is a newly collected map, so adding or removing entries does not affect
 * the registry; the {@code ChatApp} values themselves are the registered instances.
 * Apps registered without a module are skipped when a non-null {@code appType} is
 * requested, instead of triggering a {@link NullPointerException}.
 *
 * @param appType module to filter by; {@code null} selects apps that have no module set
 * @return map from app key to app for every registered app in the given module
 */
public static Map<String, ChatApp> getAllApps(AppModule appType) {
    // Enum constants are singletons, so identity comparison is exact and, unlike
    // equals(), safe when an app was built without an appModule.
    return chatApps.entrySet().stream()
            .filter(entry -> entry.getValue().getAppModule() == appType)
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
@@ -40,7 +41,7 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
public LLMSqlCorrector() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
.description("通过大模型对解析S2SQL做二次修正").enable(false).build());
.appModule(AppModule.CHAT).description("通过大模型对解析S2SQL做二次修正").enable(false).build());
}
@Data

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
@@ -44,7 +45,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
public OnePassSCSqlGenStrategy() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")
.description("通过大模型做语义解析生成S2SQL").enable(true).build());
.appModule(AppModule.CHAT).description("通过大模型做语义解析生成S2SQL").enable(true).build());
}
@Data

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.ChatAppManager;
@@ -182,7 +183,8 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig);
agentService.createAgent(agent, defaultUser);

View File

@@ -157,10 +157,11 @@ public class S2VisitsDemo extends S2BaseDemo {
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig);
Agent agentCreated = agentService.createAgent(agent, defaultUser);

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat;
import com.google.common.collect.Lists;
import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
@@ -17,6 +18,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import java.time.LocalDate;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@@ -38,6 +40,10 @@ public class BaseTest extends BaseApplication {
@Value("${s2.demo.enableLLM:false}")
protected boolean enableLLM;
protected int agentId;
protected List<Long> durations = Lists.newArrayList();
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
throws Exception {

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector;
import com.tencent.supersonic.util.DataUtils;
@@ -16,7 +17,6 @@ import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.*;
import org.springframework.test.context.TestPropertySource;
import java.util.List;
import java.util.Map;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -24,8 +24,8 @@ import java.util.Map;
@Disabled
public class Text2SQLEval extends BaseTest {
private int agentId;
private List<Long> durations = Lists.newArrayList();
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
private boolean enableLLMCorrection = true;
@BeforeAll
public void init() {
@@ -139,15 +139,17 @@ public class Text2SQLEval extends BaseTest {
ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getDatasetTool());
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// create chat model for this evaluation
ChatModel chatModel = new ChatModel();
chatModel.setName("Text2SQL LLM");
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
chatModel.setConfig(LLMConfigUtils.getLLMConfig(llmType));
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
Integer chatModelId = chatModel.getId();
// configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId));
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true);
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(enableLLMCorrection);
agent.setChatAppConfig(chatAppConfig);
return agent;
}