[improvement][chat]Introduce AppModule to classify chat apps.

This commit is contained in:
jerryjzhang
2024-10-14 14:21:41 +08:00
parent 6a28a49d31
commit 0b71390fde
13 changed files with 50 additions and 19 deletions

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.ChatAppManager;
@@ -182,7 +183,8 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig);
agentService.createAgent(agent, defaultUser);

View File

@@ -157,10 +157,11 @@ public class S2VisitsDemo extends S2BaseDemo {
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(dataSetId));
toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig);
Agent agentCreated = agentService.createAgent(agent, defaultUser);

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat;
import com.google.common.collect.Lists;
import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
@@ -17,6 +18,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import java.time.LocalDate;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@@ -38,6 +40,10 @@ public class BaseTest extends BaseApplication {
@Value("${s2.demo.enableLLM:false}")
protected boolean enableLLM;
protected int agentId;
protected List<Long> durations = Lists.newArrayList();
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
throws Exception {

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector;
import com.tencent.supersonic.util.DataUtils;
@@ -16,7 +17,6 @@ import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.*;
import org.springframework.test.context.TestPropertySource;
import java.util.List;
import java.util.Map;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -24,8 +24,8 @@ import java.util.Map;
@Disabled
public class Text2SQLEval extends BaseTest {
private int agentId;
private List<Long> durations = Lists.newArrayList();
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
private boolean enableLLMCorrection = true;
@BeforeAll
public void init() {
@@ -139,15 +139,17 @@ public class Text2SQLEval extends BaseTest {
ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getDatasetTool());
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// create chat model for this evaluation
ChatModel chatModel = new ChatModel();
chatModel.setName("Text2SQL LLM");
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
chatModel.setConfig(LLMConfigUtils.getLLMConfig(llmType));
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
Integer chatModelId = chatModel.getId();
// configure chat apps
Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId));
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true);
chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(enableLLMCorrection);
agent.setChatAppConfig(chatAppConfig);
return agent;
}