mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
[feature][headless-chat]Introduce ChatApp to support more flexible chat model config.#1739
This commit is contained in:
@@ -5,7 +5,6 @@ import com.google.common.collect.Lists;
|
|||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -22,12 +21,8 @@ public class Agent extends RecordInfo {
|
|||||||
private Integer status;
|
private Integer status;
|
||||||
private List<String> examples;
|
private List<String> examples;
|
||||||
private Integer enableSearch;
|
private Integer enableSearch;
|
||||||
private Integer enableMemoryReview;
|
|
||||||
private String toolConfig;
|
private String toolConfig;
|
||||||
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
|
|
||||||
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
|
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
|
||||||
private PromptConfig promptConfig;
|
|
||||||
private MultiTurnConfig multiTurnConfig;
|
|
||||||
private VisualConfig visualConfig;
|
private VisualConfig visualConfig;
|
||||||
|
|
||||||
public List<String> getTools(AgentToolType type) {
|
public List<String> getTools(AgentToolType type) {
|
||||||
@@ -49,7 +44,7 @@ public class Agent extends RecordInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableMemoryReview() {
|
public boolean enableMemoryReview() {
|
||||||
return enableMemoryReview != null && enableMemoryReview == 1;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import javax.servlet.http.HttpServletResponse;
|
|||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
|
|
||||||
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
|
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||||
@@ -13,15 +12,12 @@ import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
|||||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
|
||||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@RestController
|
@RestController
|
||||||
@RequestMapping({"/api/chat/model", "/openapi/chat/model"})
|
@RequestMapping({"/api/chat/model", "/openapi/chat/model"})
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
|
||||||
import com.tencent.supersonic.chat.server.agent.PromptConfig;
|
|
||||||
import com.tencent.supersonic.chat.server.agent.VisualConfig;
|
import com.tencent.supersonic.chat.server.agent.VisualConfig;
|
||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
@@ -24,8 +22,6 @@ import org.springframework.stereotype.Service;
|
|||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|||||||
@@ -1,19 +1,14 @@
|
|||||||
package com.tencent.supersonic.common.util;
|
package com.tencent.supersonic.common.util;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
|
||||||
import com.google.common.collect.Maps;
|
import com.google.common.collect.Maps;
|
||||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class ChatAppManager {
|
public class ChatAppManager {
|
||||||
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
|
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
|
||||||
|
|
||||||
public static void register(ChatApp chatApp) {
|
public static void register(ChatApp chatApp) {
|
||||||
if (chatApps.containsKey(chatApp.getKey())) {
|
|
||||||
throw new RuntimeException("Duplicate chat app key is disallowed.");
|
|
||||||
}
|
|
||||||
chatApps.put(chatApp.getKey(), chatApp);
|
chatApps.put(chatApp.getKey(), chatApp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -64,7 +64,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
ChatLanguageModel chatLanguageModel =
|
||||||
|
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||||
SemanticSqlExtractor extractor =
|
SemanticSqlExtractor extractor =
|
||||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||||
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
|
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ public class FileHandlerImpl implements FileHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
|
private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
|
||||||
DictValueReq dictValueReq) {
|
DictValueReq dictValueReq) {
|
||||||
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
|
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
|
||||||
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
|
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
|
||||||
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
|
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
|
||||||
@@ -118,11 +118,12 @@ public class FileHandlerImpl implements FileHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
|
private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
|
||||||
DictValueReq dictValueReq) {
|
DictValueReq dictValueReq) {
|
||||||
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
|
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
|
||||||
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
|
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
|
||||||
Long fileLineNum = getFileLineNum(filePath);
|
Long fileLineNum = getFileLineNum(filePath);
|
||||||
Integer startLine = Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
|
Integer startLine =
|
||||||
|
Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
|
||||||
Integer endLine = Integer.valueOf(
|
Integer endLine = Integer.valueOf(
|
||||||
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
|
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
|
||||||
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);
|
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ public class DictTaskServiceImpl implements DictTaskService {
|
|||||||
private final DimensionService dimensionService;
|
private final DimensionService dimensionService;
|
||||||
|
|
||||||
public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter,
|
public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter,
|
||||||
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
|
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
|
||||||
DimensionService dimensionService) {
|
DimensionService dimensionService) {
|
||||||
this.dictRepository = dictRepository;
|
this.dictRepository = dictRepository;
|
||||||
this.dictConverter = dictConverter;
|
this.dictConverter = dictConverter;
|
||||||
this.dictUtils = dictUtils;
|
this.dictUtils = dictUtils;
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ public class DimensionConverter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO,
|
public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO,
|
||||||
Map<Long, ModelResp> modelRespMap) {
|
Map<Long, ModelResp> modelRespMap) {
|
||||||
DimensionResp dimensionResp = new DimensionResp();
|
DimensionResp dimensionResp = new DimensionResp();
|
||||||
BeanUtils.copyProperties(dimensionDO, dimensionResp);
|
BeanUtils.copyProperties(dimensionDO, dimensionResp);
|
||||||
dimensionResp.setModelName(
|
dimensionResp.setModelName(
|
||||||
@@ -123,11 +123,11 @@ public class DimensionConverter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static List<DimensionResp> filterByDataSet(List<DimensionResp> dimensionResps,
|
public static List<DimensionResp> filterByDataSet(List<DimensionResp> dimensionResps,
|
||||||
DataSetResp dataSetResp) {
|
DataSetResp dataSetResp) {
|
||||||
return dimensionResps.stream()
|
return dimensionResps.stream()
|
||||||
.filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId())
|
.filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId())
|
||||||
|| dataSetResp.getAllIncludeAllModels()
|
|| dataSetResp.getAllIncludeAllModels()
|
||||||
.contains(dimensionResp.getModelId()))
|
.contains(dimensionResp.getModelId()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
|||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
|
||||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\
|
|||||||
com.tencent.supersonic.headless.chat.parser.QueryTypeParser
|
com.tencent.supersonic.headless.chat.parser.QueryTypeParser
|
||||||
|
|
||||||
com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\
|
com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\
|
||||||
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector
|
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector,\
|
||||||
|
com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector
|
||||||
|
|
||||||
com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\
|
com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\
|
||||||
com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl
|
com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
public void init() {
|
public void init() {
|
||||||
Agent agent = agentService.createAgent(getLLMAgent(false), DataUtils.getUser());
|
Agent agent = agentService.createAgent(getLLMAgent(), DataUtils.getUser());
|
||||||
agentId = agent.getId();
|
agentId = agent.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
assert result.getTextResult().contains("3");
|
assert result.getTextResult().contains("3");
|
||||||
}
|
}
|
||||||
|
|
||||||
public Agent getLLMAgent(boolean enableMultiturn) {
|
public Agent getLLMAgent() {
|
||||||
Agent agent = new Agent();
|
Agent agent = new Agent();
|
||||||
agent.setName("Agent for Test");
|
agent.setName("Agent for Test");
|
||||||
ToolConfig toolConfig = new ToolConfig();
|
ToolConfig toolConfig = new ToolConfig();
|
||||||
|
|||||||
Reference in New Issue
Block a user