mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[feature][headless-chat]Introduce ChatApp to support more flexible chat model config.#1739
This commit is contained in:
@@ -5,7 +5,6 @@ import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -22,12 +21,8 @@ public class Agent extends RecordInfo {
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private Integer enableSearch;
|
||||
private Integer enableMemoryReview;
|
||||
private String toolConfig;
|
||||
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
|
||||
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
|
||||
private PromptConfig promptConfig;
|
||||
private MultiTurnConfig multiTurnConfig;
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
@@ -49,7 +44,7 @@ public class Agent extends RecordInfo {
|
||||
}
|
||||
|
||||
public boolean enableMemoryReview() {
|
||||
return enableMemoryReview != null && enableMemoryReview == 1;
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||
|
||||
@@ -5,7 +5,6 @@ import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
|
||||
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||
@@ -13,15 +12,12 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/model", "/openapi/chat/model"})
|
||||
|
||||
@@ -5,8 +5,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.PromptConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.VisualConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
@@ -24,8 +22,6 @@ import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ChatAppManager {
|
||||
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
|
||||
|
||||
public static void register(ChatApp chatApp) {
|
||||
if (chatApps.containsKey(chatApp.getKey())) {
|
||||
throw new RuntimeException("Duplicate chat app key is disallowed.");
|
||||
}
|
||||
chatApps.put(chatApp.getKey(), chatApp);
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
return;
|
||||
}
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
|
||||
|
||||
@@ -86,7 +86,7 @@ public class FileHandlerImpl implements FileHandler {
|
||||
}
|
||||
|
||||
private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
|
||||
DictValueReq dictValueReq) {
|
||||
DictValueReq dictValueReq) {
|
||||
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
|
||||
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
|
||||
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
|
||||
@@ -118,11 +118,12 @@ public class FileHandlerImpl implements FileHandler {
|
||||
}
|
||||
|
||||
private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
|
||||
DictValueReq dictValueReq) {
|
||||
DictValueReq dictValueReq) {
|
||||
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
|
||||
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
|
||||
Long fileLineNum = getFileLineNum(filePath);
|
||||
Integer startLine = Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
|
||||
Integer startLine =
|
||||
Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
|
||||
Integer endLine = Integer.valueOf(
|
||||
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
|
||||
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);
|
||||
|
||||
@@ -59,8 +59,8 @@ public class DictTaskServiceImpl implements DictTaskService {
|
||||
private final DimensionService dimensionService;
|
||||
|
||||
public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter,
|
||||
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
|
||||
DimensionService dimensionService) {
|
||||
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
|
||||
DimensionService dimensionService) {
|
||||
this.dictRepository = dictRepository;
|
||||
this.dictConverter = dictConverter;
|
||||
this.dictUtils = dictUtils;
|
||||
|
||||
@@ -75,7 +75,7 @@ public class DimensionConverter {
|
||||
}
|
||||
|
||||
public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO,
|
||||
Map<Long, ModelResp> modelRespMap) {
|
||||
Map<Long, ModelResp> modelRespMap) {
|
||||
DimensionResp dimensionResp = new DimensionResp();
|
||||
BeanUtils.copyProperties(dimensionDO, dimensionResp);
|
||||
dimensionResp.setModelName(
|
||||
@@ -123,11 +123,11 @@ public class DimensionConverter {
|
||||
}
|
||||
|
||||
public static List<DimensionResp> filterByDataSet(List<DimensionResp> dimensionResps,
|
||||
DataSetResp dataSetResp) {
|
||||
DataSetResp dataSetResp) {
|
||||
return dimensionResps.stream()
|
||||
.filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId())
|
||||
|| dataSetResp.getAllIncludeAllModels()
|
||||
.contains(dimensionResp.getModelId()))
|
||||
.contains(dimensionResp.getModelId()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
|
||||
@@ -13,7 +13,8 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\
|
||||
com.tencent.supersonic.headless.chat.parser.QueryTypeParser
|
||||
|
||||
com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\
|
||||
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector
|
||||
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector,\
|
||||
com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector
|
||||
|
||||
com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\
|
||||
com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl
|
||||
|
||||
@@ -29,7 +29,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
|
||||
@BeforeAll
|
||||
public void init() {
|
||||
Agent agent = agentService.createAgent(getLLMAgent(false), DataUtils.getUser());
|
||||
Agent agent = agentService.createAgent(getLLMAgent(), DataUtils.getUser());
|
||||
agentId = agent.getId();
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
assert result.getTextResult().contains("3");
|
||||
}
|
||||
|
||||
public Agent getLLMAgent(boolean enableMultiturn) {
|
||||
public Agent getLLMAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
|
||||
Reference in New Issue
Block a user