mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[feature][chat]Refactor chat model config related codes.#1739
This commit is contained in:
@@ -3,8 +3,6 @@ package com.tencent.supersonic.chat.server.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import lombok.Data;
|
||||
@@ -17,22 +15,21 @@ import java.util.stream.Collectors;
|
||||
public class Agent extends RecordInfo {
|
||||
|
||||
private Integer id;
|
||||
private Integer enableSearch;
|
||||
private Integer enableMemoryReview;
|
||||
private String name;
|
||||
private String description;
|
||||
|
||||
/** 0 offline, 1 online */
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
private Map<ChatModelType, Integer> modelConfig = Collections.EMPTY_MAP;
|
||||
private Integer enableSearch;
|
||||
private Integer enableMemoryReview;
|
||||
private String toolConfig;
|
||||
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
|
||||
private PromptConfig promptConfig;
|
||||
private MultiTurnConfig multiTurnConfig;
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||
Map map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
@@ -84,7 +81,7 @@ public class Agent extends RecordInfo {
|
||||
}
|
||||
|
||||
public boolean containsAnyTool() {
|
||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||
Map map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class PromptConfig {
|
||||
|
||||
private String promptTemplate;
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import java.util.List;
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class AgentConfig {
|
||||
public class ToolConfig {
|
||||
|
||||
List<AgentTool> tools = Lists.newArrayList();
|
||||
private List<AgentTool> tools = Lists.newArrayList();
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class VisualConfig {
|
||||
|
||||
private boolean enableSimpleMode;
|
||||
|
||||
private boolean showDebugInfo;
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package com.tencent.supersonic.chat.server.config;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ChatModelParameters {
|
||||
private static final String MODULE_NAME = "对话模型配置";
|
||||
|
||||
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||
new Parameter("provider", ModelProvider.DEMO_CHAT_MODEL.getProvider(), "接口协议", "",
|
||||
"list", MODULE_NAME, getCandidateValues());
|
||||
|
||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||
new Parameter("baseUrl", ModelProvider.DEMO_CHAT_MODEL.getBaseUrl(), "BaseUrl", "",
|
||||
"string", MODULE_NAME, null, getBaseUrlDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_NAME =
|
||||
new Parameter("modelName", ModelProvider.DEMO_CHAT_MODEL.getModelName(), "ModelName",
|
||||
"", "string", MODULE_NAME, null, getModelNameDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_API_KEY =
|
||||
new Parameter("apiKey", ModelProvider.DEMO_CHAT_MODEL.getApiKey(), "ApiKey", "",
|
||||
"password", MODULE_NAME, null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
|
||||
"Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("secretKey", "demo",
|
||||
"SecretKey", "", "password", MODULE_NAME, null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_ENABLE_SEARCH = new Parameter("enableSearch", "false",
|
||||
"是否启用搜索增强功能,设为false表示不启用", "", "bool", MODULE_NAME, null, getEnableSearchDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||
new Parameter("temperature", "0.0", "Temperature", "", "slider", MODULE_NAME);
|
||||
|
||||
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||
new Parameter("timeOut", "60", "超时时间(秒)", "", "number", MODULE_NAME);
|
||||
|
||||
public static List<Parameter> getParameters() {
|
||||
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
||||
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
|
||||
}
|
||||
|
||||
private static List<String> getCandidateValues() {
|
||||
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
|
||||
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
||||
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(
|
||||
QianfanModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
|
||||
List<String> includesValue, Map<String, String> setDefaultValue) {
|
||||
|
||||
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
|
||||
show.setIncludesValue(includesValue);
|
||||
|
||||
Parameter.Dependency dependency = new Parameter.Dependency();
|
||||
dependency.setName(dependencyParameterName);
|
||||
dependency.setShow(show);
|
||||
dependency.setSetDefaultValue(setDefaultValue);
|
||||
List<Parameter.Dependency> dependencies = new ArrayList<>();
|
||||
dependencies.add(dependency);
|
||||
return dependencies;
|
||||
}
|
||||
}
|
||||
@@ -11,5 +11,5 @@ public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_MULTI_TURN_ENABLE =
|
||||
new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token",
|
||||
"bool", "Parser相关配置");
|
||||
"bool", "语义解析配置");
|
||||
}
|
||||
|
||||
@@ -10,43 +10,34 @@ import java.util.Date;
|
||||
@Data
|
||||
@TableName("s2_agent")
|
||||
public class AgentDO {
|
||||
/** */
|
||||
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Integer id;
|
||||
|
||||
/** */
|
||||
private String name;
|
||||
|
||||
/** */
|
||||
private String description;
|
||||
|
||||
/** 0 offline, 1 online */
|
||||
private Integer status;
|
||||
|
||||
/** */
|
||||
private String examples;
|
||||
|
||||
/** */
|
||||
private String config;
|
||||
|
||||
/** */
|
||||
private String createdBy;
|
||||
|
||||
/** */
|
||||
private Date createdAt;
|
||||
|
||||
/** */
|
||||
private String updatedBy;
|
||||
|
||||
/** */
|
||||
private Date updatedAt;
|
||||
|
||||
/** */
|
||||
private Integer enableSearch;
|
||||
|
||||
private Integer enableMemoryReview;
|
||||
|
||||
private String modelConfig;
|
||||
private String toolConfig;
|
||||
|
||||
private String chatModelConfig;
|
||||
|
||||
private String multiTurnConfig;
|
||||
|
||||
|
||||
@@ -59,8 +59,4 @@ public class AgentController {
|
||||
return AgentToolType.getToolTypes();
|
||||
}
|
||||
|
||||
@PostMapping("/testLLMConn")
|
||||
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
|
||||
return ModelConfigHelper.testConnection(modelConfig);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,10 +6,12 @@ import javax.servlet.http.HttpServletResponse;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
|
||||
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
@@ -56,6 +58,11 @@ public class ChatModelController {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@RequestMapping("/getModelParameters")
|
||||
public List<Parameter> getModelParameters() {
|
||||
return ChatModelParameters.getParameters();
|
||||
}
|
||||
|
||||
@PostMapping("/testConnection")
|
||||
public boolean testConnection(@RequestBody ChatModelConfig modelConfig) {
|
||||
return ModelConfigHelper.testConnection(modelConfig);
|
||||
|
||||
@@ -6,6 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.PromptConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.VisualConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
@@ -14,8 +16,6 @@ import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
|
||||
@@ -140,10 +140,10 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
}
|
||||
Agent agent = new Agent();
|
||||
BeanUtils.copyProperties(agentDO, agent);
|
||||
agent.setAgentConfig(agentDO.getConfig());
|
||||
agent.setToolConfig(agentDO.getToolConfig());
|
||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||
agent.setModelConfig(
|
||||
JsonUtil.toMap(agentDO.getModelConfig(), ChatModelType.class, Integer.class));
|
||||
agent.setChatModelConfig(
|
||||
JsonUtil.toMap(agentDO.getChatModelConfig(), ChatModelType.class, Integer.class));
|
||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||
agent.setMultiTurnConfig(
|
||||
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
@@ -154,9 +154,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
private AgentDO convert(Agent agent) {
|
||||
AgentDO agentDO = new AgentDO();
|
||||
BeanUtils.copyProperties(agent, agentDO);
|
||||
agentDO.setConfig(agent.getAgentConfig());
|
||||
agentDO.setToolConfig(agent.getToolConfig());
|
||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||
agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig()));
|
||||
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatModelConfig()));
|
||||
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
|
||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));
|
||||
|
||||
@@ -47,6 +47,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
||||
chatModelDO.setAdmin(user.getName());
|
||||
}
|
||||
save(chatModelDO);
|
||||
chatModel.setId(chatModelDO.getId());
|
||||
return chatModel;
|
||||
}
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||
executeReq.setQueryText(chatParseReq.getQueryText());
|
||||
executeReq.setChatId(chatParseReq.getChatId());
|
||||
executeReq.setUser(User.getFakeUser());
|
||||
executeReq.setUser(User.getDefaultUser());
|
||||
executeReq.setAgentId(chatParseReq.getAgentId());
|
||||
executeReq.setSaveAnswer(true);
|
||||
return execute(executeReq);
|
||||
|
||||
@@ -29,8 +29,8 @@ public class ModelConfigHelper {
|
||||
|
||||
public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) {
|
||||
ChatModelConfig chatModelConfig = null;
|
||||
if (agent.getModelConfig().containsKey(modelType)) {
|
||||
Integer chatModelId = agent.getModelConfig().get(modelType);
|
||||
if (agent.getChatModelConfig().containsKey(modelType)) {
|
||||
Integer chatModelId = agent.getChatModelConfig().get(modelType);
|
||||
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
|
||||
chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig();
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ public class QueryReqConverter {
|
||||
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
|
||||
}
|
||||
queryNLReq.setModelConfig(chatModelConfig);
|
||||
queryNLReq.setPromptConfig(agent.getPromptConfig());
|
||||
queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate());
|
||||
if (chatCtx != null) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user