mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[improvement][chat]Introduce AppModule to classify chat apps.
This commit is contained in:
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
@@ -32,7 +33,7 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
public PlainTextExecutor() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("闲聊对话")
|
||||
.description("直接将原始输入透传大模型").enable(false).build());
|
||||
.appModule(AppModule.CHAT).description("直接将原始输入透传大模型").enable(false).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
@@ -49,7 +50,8 @@ public class MemoryReviewTask {
|
||||
private AgentService agentService;
|
||||
|
||||
public MemoryReviewTask() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估")
|
||||
ChatAppManager.register(APP_KEY,
|
||||
ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估").appModule(AppModule.CHAT)
|
||||
.description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build());
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -79,11 +80,13 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
public NL2SQLParser() {
|
||||
ChatAppManager.register(APP_KEY_MULTI_TURN,
|
||||
ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写")
|
||||
.description("通过大模型根据历史对话来改写本轮对话").enable(false).build());
|
||||
.appModule(AppModule.CHAT).description("通过大模型根据历史对话来改写本轮对话").enable(false)
|
||||
.build());
|
||||
|
||||
ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
|
||||
ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
|
||||
.description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build());
|
||||
.appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语")
|
||||
.enable(false).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
@@ -52,7 +53,7 @@ public class ChatModelController {
|
||||
|
||||
@RequestMapping("/getModelAppList")
|
||||
public Map<String, ChatApp> getChatAppList() {
|
||||
return ChatAppManager.getAllApps();
|
||||
return ChatAppManager.getAllApps(AppModule.CHAT);
|
||||
}
|
||||
|
||||
@RequestMapping("/getModelParameters")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -18,4 +19,6 @@ public class ChatApp {
|
||||
private Integer chatModelId;
|
||||
@JsonIgnore
|
||||
private ChatModelConfig chatModelConfig;
|
||||
@JsonIgnore
|
||||
private AppModule appModule;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum AppModule {
|
||||
CHAT, HEADLESS
|
||||
}
|
||||
@@ -2,8 +2,10 @@ package com.tencent.supersonic.common.util;
|
||||
|
||||
import com.google.common.collect.Maps;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ChatAppManager {
|
||||
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
|
||||
@@ -12,7 +14,8 @@ public class ChatAppManager {
|
||||
chatApps.put(key, app);
|
||||
}
|
||||
|
||||
public static Map<String, ChatApp> getAllApps() {
|
||||
return chatApps;
|
||||
public static Map<String, ChatApp> getAllApps(AppModule appType) {
|
||||
return chatApps.entrySet().stream().filter(e -> e.getValue().getAppModule().equals(appType))
|
||||
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -40,7 +41,7 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
|
||||
public LLMSqlCorrector() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
|
||||
.description("通过大模型对解析S2SQL做二次修正").enable(false).build());
|
||||
.appModule(AppModule.CHAT).description("通过大模型对解析S2SQL做二次修正").enable(false).build());
|
||||
}
|
||||
|
||||
@Data
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
@@ -44,7 +45,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
public OnePassSCSqlGenStrategy() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")
|
||||
.description("通过大模型做语义解析生成S2SQL").enable(true).build());
|
||||
.appModule(AppModule.CHAT).description("通过大模型做语义解析生成S2SQL").enable(true).build());
|
||||
}
|
||||
|
||||
@Data
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
@@ -182,7 +183,8 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
|
||||
// configure chat apps
|
||||
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
|
||||
Map<String, ChatApp> chatAppConfig =
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
agentService.createAgent(agent, defaultUser);
|
||||
|
||||
@@ -157,10 +157,11 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(datasetTool);
|
||||
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
|
||||
// configure chat apps
|
||||
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
|
||||
Map<String, ChatApp> chatAppConfig =
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
Agent agentCreated = agentService.createAgent(agent, defaultUser);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.BaseApplication;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
@@ -17,6 +18,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -38,6 +40,10 @@ public class BaseTest extends BaseApplication {
|
||||
|
||||
@Value("${s2.demo.enableLLM:false}")
|
||||
protected boolean enableLLM;
|
||||
protected int agentId;
|
||||
|
||||
|
||||
protected List<Long> durations = Lists.newArrayList();
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
|
||||
throws Exception {
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.*;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
@@ -16,7 +17,6 @@ import com.tencent.supersonic.util.LLMConfigUtils;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.springframework.test.context.TestPropertySource;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
@@ -24,8 +24,8 @@ import java.util.Map;
|
||||
@Disabled
|
||||
public class Text2SQLEval extends BaseTest {
|
||||
|
||||
private int agentId;
|
||||
private List<Long> durations = Lists.newArrayList();
|
||||
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
||||
private boolean enableLLMCorrection = true;
|
||||
|
||||
@BeforeAll
|
||||
public void init() {
|
||||
@@ -139,15 +139,17 @@ public class Text2SQLEval extends BaseTest {
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
toolConfig.getTools().add(getDatasetTool());
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
// create chat model for this evaluation
|
||||
ChatModel chatModel = new ChatModel();
|
||||
chatModel.setName("Text2SQL LLM");
|
||||
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
|
||||
chatModel.setConfig(LLMConfigUtils.getLLMConfig(llmType));
|
||||
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
|
||||
Integer chatModelId = chatModel.getId();
|
||||
// configure chat apps
|
||||
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
|
||||
Map<String, ChatApp> chatAppConfig =
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId));
|
||||
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true);
|
||||
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(enableLLMCorrection);
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
return agent;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user