mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-01-01 06:47:47 +08:00
[improvement][chat]Introduce AppModule to classify chat apps.
This commit is contained in:
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
@@ -182,7 +183,8 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
|
||||
// configure chat apps
|
||||
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
|
||||
Map<String, ChatApp> chatAppConfig =
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
agentService.createAgent(agent, defaultUser);
|
||||
|
||||
@@ -157,10 +157,11 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
toolConfig.getTools().add(datasetTool);
|
||||
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
|
||||
// configure chat apps
|
||||
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
|
||||
Map<String, ChatApp> chatAppConfig =
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
Agent agentCreated = agentService.createAgent(agent, defaultUser);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.BaseApplication;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
@@ -17,6 +18,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -38,6 +40,10 @@ public class BaseTest extends BaseApplication {
|
||||
|
||||
@Value("${s2.demo.enableLLM:false}")
|
||||
protected boolean enableLLM;
|
||||
protected int agentId;
|
||||
|
||||
|
||||
protected List<Long> durations = Lists.newArrayList();
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
|
||||
throws Exception {
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.*;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
@@ -16,7 +17,6 @@ import com.tencent.supersonic.util.LLMConfigUtils;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.springframework.test.context.TestPropertySource;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
@@ -24,8 +24,8 @@ import java.util.Map;
|
||||
@Disabled
|
||||
public class Text2SQLEval extends BaseTest {
|
||||
|
||||
private int agentId;
|
||||
private List<Long> durations = Lists.newArrayList();
|
||||
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
||||
private boolean enableLLMCorrection = true;
|
||||
|
||||
@BeforeAll
|
||||
public void init() {
|
||||
@@ -139,15 +139,17 @@ public class Text2SQLEval extends BaseTest {
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
toolConfig.getTools().add(getDatasetTool());
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
// create chat model for this evaluation
|
||||
ChatModel chatModel = new ChatModel();
|
||||
chatModel.setName("Text2SQL LLM");
|
||||
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
|
||||
chatModel.setConfig(LLMConfigUtils.getLLMConfig(llmType));
|
||||
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
|
||||
Integer chatModelId = chatModel.getId();
|
||||
// configure chat apps
|
||||
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
|
||||
Map<String, ChatApp> chatAppConfig =
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId));
|
||||
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true);
|
||||
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(enableLLMCorrection);
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
return agent;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user