[feature][headless-chat] Introduce ChatApp to support more flexible chat model config. #1739

This commit is contained in:
jerryjzhang
2024-10-12 13:00:44 +08:00
parent 0cce0a76b4
commit 3501f592e7
11 changed files with 16 additions and 32 deletions

View File

@@ -5,7 +5,6 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import lombok.Data;
import org.springframework.util.CollectionUtils;
@@ -22,12 +21,8 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private Integer enableSearch;
private Integer enableMemoryReview;
private String toolConfig;
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;
public List<String> getTools(AgentToolType type) {
@@ -49,7 +44,7 @@ public class Agent extends RecordInfo {
}
public boolean enableMemoryReview() {
return enableMemoryReview != null && enableMemoryReview == 1;
return false;
}
public static boolean containsAllModel(Set<Long> detectViewIds) {

View File

@@ -5,7 +5,6 @@ import javax.servlet.http.HttpServletResponse;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService;
@@ -13,15 +12,12 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.ChatAppManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
@RestController
@RequestMapping({"/api/chat/model", "/openapi/chat/model"})

View File

@@ -5,8 +5,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.PromptConfig;
import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
@@ -24,8 +22,6 @@ import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

View File

@@ -1,19 +1,14 @@
package com.tencent.supersonic.common.util;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.supersonic.common.pojo.ChatApp;
import java.util.List;
import java.util.Map;
public class ChatAppManager {
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
/**
 * Registers a chat app under the key returned by its {@code getKey()} method.
 *
 * <p>The insert is done with a single {@code putIfAbsent} call rather than a
 * {@code containsKey} check followed by {@code put}. With the two-step form, two
 * threads registering the same key could both pass the check and the later
 * {@code put} would silently replace the earlier registration.
 *
 * @param chatApp the chat app to register
 * @throws RuntimeException if a chat app is already registered under the same key
 */
public static void register(ChatApp chatApp) {
    // putIfAbsent returns the previously mapped value, or null when the key was free.
    if (chatApps.putIfAbsent(chatApp.getKey(), chatApp) != null) {
        throw new RuntimeException("Duplicate chat app key is disallowed.");
    }
}

View File

@@ -64,7 +64,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
return;
}
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig());
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,

View File

@@ -86,7 +86,7 @@ public class FileHandlerImpl implements FileHandler {
}
private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
DictValueReq dictValueReq) {
DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
@@ -118,11 +118,12 @@ public class FileHandlerImpl implements FileHandler {
}
private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
DictValueReq dictValueReq) {
DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
Long fileLineNum = getFileLineNum(filePath);
Integer startLine = Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
Integer startLine =
Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
Integer endLine = Integer.valueOf(
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);

View File

@@ -59,8 +59,8 @@ public class DictTaskServiceImpl implements DictTaskService {
private final DimensionService dimensionService;
public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter,
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
DimensionService dimensionService) {
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
DimensionService dimensionService) {
this.dictRepository = dictRepository;
this.dictConverter = dictConverter;
this.dictUtils = dictUtils;

View File

@@ -75,7 +75,7 @@ public class DimensionConverter {
}
public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO,
Map<Long, ModelResp> modelRespMap) {
Map<Long, ModelResp> modelRespMap) {
DimensionResp dimensionResp = new DimensionResp();
BeanUtils.copyProperties(dimensionDO, dimensionResp);
dimensionResp.setModelName(
@@ -123,11 +123,11 @@ public class DimensionConverter {
}
public static List<DimensionResp> filterByDataSet(List<DimensionResp> dimensionResps,
DataSetResp dataSetResp) {
DataSetResp dataSetResp) {
return dimensionResps.stream()
.filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId())
|| dataSetResp.getAllIncludeAllModels()
.contains(dimensionResp.getModelId()))
.contains(dimensionResp.getModelId()))
.collect(Collectors.toList());
}
}

View File

@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;

View File

@@ -13,7 +13,8 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\
com.tencent.supersonic.headless.chat.parser.QueryTypeParser
com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector,\
com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector
com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\
com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl

View File

@@ -29,7 +29,7 @@ public class Text2SQLEval extends BaseTest {
@BeforeAll
public void init() {
Agent agent = agentService.createAgent(getLLMAgent(false), DataUtils.getUser());
Agent agent = agentService.createAgent(getLLMAgent(), DataUtils.getUser());
agentId = agent.getId();
}
@@ -133,7 +133,7 @@ public class Text2SQLEval extends BaseTest {
assert result.getTextResult().contains("3");
}
public Agent getLLMAgent(boolean enableMultiturn) {
public Agent getLLMAgent() {
Agent agent = new Agent();
agent.setName("Agent for Test");
ToolConfig toolConfig = new ToolConfig();