[feature][headless-chat]Introduce ChatApp to support more flexible chat model config.#1739

This commit is contained in:
jerryjzhang
2024-10-12 13:00:44 +08:00
parent 0cce0a76b4
commit 3501f592e7
11 changed files with 16 additions and 32 deletions

View File

@@ -5,7 +5,6 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import lombok.Data; import lombok.Data;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -22,12 +21,8 @@ public class Agent extends RecordInfo {
private Integer status; private Integer status;
private List<String> examples; private List<String> examples;
private Integer enableSearch; private Integer enableSearch;
private Integer enableMemoryReview;
private String toolConfig; private String toolConfig;
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP; private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig; private VisualConfig visualConfig;
public List<String> getTools(AgentToolType type) { public List<String> getTools(AgentToolType type) {
@@ -49,7 +44,7 @@ public class Agent extends RecordInfo {
} }
public boolean enableMemoryReview() { public boolean enableMemoryReview() {
return enableMemoryReview != null && enableMemoryReview == 1; return false;
} }
public static boolean containsAllModel(Set<Long> detectViewIds) { public static boolean containsAllModel(Set<Long> detectViewIds) {

View File

@@ -5,7 +5,6 @@ import javax.servlet.http.HttpServletResponse;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
import com.tencent.supersonic.chat.server.config.ChatModelParameters; import com.tencent.supersonic.chat.server.config.ChatModelParameters;
import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatModelService;
@@ -13,15 +12,12 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
@RestController @RestController
@RequestMapping({"/api/chat/model", "/openapi/chat/model"}) @RequestMapping({"/api/chat/model", "/openapi/chat/model"})

View File

@@ -5,8 +5,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.PromptConfig;
import com.tencent.supersonic.chat.server.agent.VisualConfig; import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
@@ -24,8 +22,6 @@ import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.stream.Collectors; import java.util.stream.Collectors;

View File

@@ -1,19 +1,14 @@
package com.tencent.supersonic.common.util; package com.tencent.supersonic.common.util;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import java.util.List;
import java.util.Map; import java.util.Map;
public class ChatAppManager { public class ChatAppManager {
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap(); private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
public static void register(ChatApp chatApp) { public static void register(ChatApp chatApp) {
if (chatApps.containsKey(chatApp.getKey())) {
throw new RuntimeException("Duplicate chat app key is disallowed.");
}
chatApps.put(chatApp.getKey(), chatApp); chatApps.put(chatApp.getKey(), chatApp);
} }

View File

@@ -64,7 +64,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
return; return;
} }
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig()); ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor = SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo, Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,

View File

@@ -86,7 +86,7 @@ public class FileHandlerImpl implements FileHandler {
} }
private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName, private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
DictValueReq dictValueReq) { DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>(); PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize()); dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent()); dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
@@ -118,11 +118,12 @@ public class FileHandlerImpl implements FileHandler {
} }
private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName, private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
DictValueReq dictValueReq) { DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>(); PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName; String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
Long fileLineNum = getFileLineNum(filePath); Long fileLineNum = getFileLineNum(filePath);
Integer startLine = Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1); Integer startLine =
Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
Integer endLine = Integer.valueOf( Integer endLine = Integer.valueOf(
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + ""); Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine); List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);

View File

@@ -59,8 +59,8 @@ public class DictTaskServiceImpl implements DictTaskService {
private final DimensionService dimensionService; private final DimensionService dimensionService;
public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter, public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter,
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService, DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
DimensionService dimensionService) { DimensionService dimensionService) {
this.dictRepository = dictRepository; this.dictRepository = dictRepository;
this.dictConverter = dictConverter; this.dictConverter = dictConverter;
this.dictUtils = dictUtils; this.dictUtils = dictUtils;

View File

@@ -75,7 +75,7 @@ public class DimensionConverter {
} }
public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO, public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO,
Map<Long, ModelResp> modelRespMap) { Map<Long, ModelResp> modelRespMap) {
DimensionResp dimensionResp = new DimensionResp(); DimensionResp dimensionResp = new DimensionResp();
BeanUtils.copyProperties(dimensionDO, dimensionResp); BeanUtils.copyProperties(dimensionDO, dimensionResp);
dimensionResp.setModelName( dimensionResp.setModelName(
@@ -123,11 +123,11 @@ public class DimensionConverter {
} }
public static List<DimensionResp> filterByDataSet(List<DimensionResp> dimensionResps, public static List<DimensionResp> filterByDataSet(List<DimensionResp> dimensionResps,
DataSetResp dataSetResp) { DataSetResp dataSetResp) {
return dimensionResps.stream() return dimensionResps.stream()
.filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId()) .filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId())
|| dataSetResp.getAllIncludeAllModels() || dataSetResp.getAllIncludeAllModels()
.contains(dimensionResp.getModelId())) .contains(dimensionResp.getModelId()))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
} }

View File

@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool; import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;

View File

@@ -13,7 +13,8 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\
com.tencent.supersonic.headless.chat.parser.QueryTypeParser com.tencent.supersonic.headless.chat.parser.QueryTypeParser
com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\ com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector,\
com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector
com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\ com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\
com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl

View File

@@ -29,7 +29,7 @@ public class Text2SQLEval extends BaseTest {
@BeforeAll @BeforeAll
public void init() { public void init() {
Agent agent = agentService.createAgent(getLLMAgent(false), DataUtils.getUser()); Agent agent = agentService.createAgent(getLLMAgent(), DataUtils.getUser());
agentId = agent.getId(); agentId = agent.getId();
} }
@@ -133,7 +133,7 @@ public class Text2SQLEval extends BaseTest {
assert result.getTextResult().contains("3"); assert result.getTextResult().contains("3");
} }
public Agent getLLMAgent(boolean enableMultiturn) { public Agent getLLMAgent() {
Agent agent = new Agent(); Agent agent = new Agent();
agent.setName("Agent for Test"); agent.setName("Agent for Test");
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();