[feature][chat]Refactor chat model config related codes.#1739

This commit is contained in:
jerryjzhang
2024-10-09 17:27:07 +08:00
parent 60b0a1a1a1
commit 248f4f83f6
53 changed files with 275 additions and 251 deletions

View File

@@ -3,8 +3,6 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import lombok.Data;
@@ -17,22 +15,21 @@ import java.util.stream.Collectors;
public class Agent extends RecordInfo {
private Integer id;
private Integer enableSearch;
private Integer enableMemoryReview;
private String name;
private String description;
/** 0 offline, 1 online */
private Integer status;
private List<String> examples;
private String agentConfig;
private Map<ChatModelType, Integer> modelConfig = Collections.EMPTY_MAP;
private Integer enableSearch;
private Integer enableMemoryReview;
private String toolConfig;
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class);
Map map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
return Lists.newArrayList();
}
@@ -84,7 +81,7 @@ public class Agent extends RecordInfo {
}
public boolean containsAnyTool() {
Map map = JSONObject.parseObject(agentConfig, Map.class);
Map map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map)) {
return false;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PromptConfig {
private String promptTemplate;
}

View File

@@ -10,7 +10,7 @@ import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class AgentConfig {
public class ToolConfig {
List<AgentTool> tools = Lists.newArrayList();
private List<AgentTool> tools = Lists.newArrayList();
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
@Data
public class VisualConfig {
private boolean enableSimpleMode;
private boolean showDebugInfo;
}

View File

@@ -0,0 +1,127 @@
package com.tencent.supersonic.chat.server.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class ChatModelParameters {
private static final String MODULE_NAME = "对话模型配置";
public static final Parameter CHAT_MODEL_PROVIDER =
new Parameter("provider", ModelProvider.DEMO_CHAT_MODEL.getProvider(), "接口协议", "",
"list", MODULE_NAME, getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("baseUrl", ModelProvider.DEMO_CHAT_MODEL.getBaseUrl(), "BaseUrl", "",
"string", MODULE_NAME, null, getBaseUrlDependency());
public static final Parameter CHAT_MODEL_NAME =
new Parameter("modelName", ModelProvider.DEMO_CHAT_MODEL.getModelName(), "ModelName",
"", "string", MODULE_NAME, null, getModelNameDependency());
public static final Parameter CHAT_MODEL_API_KEY =
new Parameter("apiKey", ModelProvider.DEMO_CHAT_MODEL.getApiKey(), "ApiKey", "",
"password", MODULE_NAME, null, getApiKeyDependency());
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
"Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("secretKey", "demo",
"SecretKey", "", "password", MODULE_NAME, null, getSecretKeyDependency());
public static final Parameter CHAT_MODEL_ENABLE_SEARCH = new Parameter("enableSearch", "false",
"是否启用搜索增强功能设为false表示不启用", "", "bool", MODULE_NAME, null, getEnableSearchDependency());
public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter("temperature", "0.0", "Temperature", "", "slider", MODULE_NAME);
public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("timeOut", "60", "超时时间(秒)", "", "number", MODULE_NAME);
public static List<Parameter> getParameters() {
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
}
private static List<String> getCandidateValues() {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL));
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
}
private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME));
}
private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
}
private static List<Parameter.Dependency> getEnableSearchDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
}
private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(
QianfanModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
}
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
List<String> includesValue, Map<String, String> setDefaultValue) {
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
show.setIncludesValue(includesValue);
Parameter.Dependency dependency = new Parameter.Dependency();
dependency.setName(dependencyParameterName);
dependency.setShow(show);
dependency.setSetDefaultValue(setDefaultValue);
List<Parameter.Dependency> dependencies = new ArrayList<>();
dependencies.add(dependency);
return dependencies;
}
}

View File

@@ -11,5 +11,5 @@ public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_MULTI_TURN_ENABLE =
new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token",
"bool", "Parser相关配置");
"bool", "语义解析配置");
}

View File

@@ -10,43 +10,34 @@ import java.util.Date;
@Data
@TableName("s2_agent")
public class AgentDO {
/** */
@TableId(type = IdType.AUTO)
private Integer id;
/** */
private String name;
/** */
private String description;
/** 0 offline, 1 online */
private Integer status;
/** */
private String examples;
/** */
private String config;
/** */
private String createdBy;
/** */
private Date createdAt;
/** */
private String updatedBy;
/** */
private Date updatedAt;
/** */
private Integer enableSearch;
private Integer enableMemoryReview;
private String modelConfig;
private String toolConfig;
private String chatModelConfig;
private String multiTurnConfig;

View File

@@ -59,8 +59,4 @@ public class AgentController {
return AgentToolType.getToolTypes();
}
@PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
return ModelConfigHelper.testConnection(modelConfig);
}
}

View File

@@ -6,10 +6,12 @@ import javax.servlet.http.HttpServletResponse;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@@ -56,6 +58,11 @@ public class ChatModelController {
.collect(Collectors.toList());
}
@RequestMapping("/getModelParameters")
public List<Parameter> getModelParameters() {
return ChatModelParameters.getParameters();
}
@PostMapping("/testConnection")
public boolean testConnection(@RequestBody ChatModelConfig modelConfig) {
return ModelConfigHelper.testConnection(modelConfig);

View File

@@ -6,6 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.PromptConfig;
import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
@@ -14,8 +16,6 @@ import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
@@ -140,10 +140,10 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
}
Agent agent = new Agent();
BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig());
agent.setToolConfig(agentDO.getToolConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setModelConfig(
JsonUtil.toMap(agentDO.getModelConfig(), ChatModelType.class, Integer.class));
agent.setChatModelConfig(
JsonUtil.toMap(agentDO.getChatModelConfig(), ChatModelType.class, Integer.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
@@ -154,9 +154,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
private AgentDO convert(Agent agent) {
AgentDO agentDO = new AgentDO();
BeanUtils.copyProperties(agent, agentDO);
agentDO.setConfig(agent.getAgentConfig());
agentDO.setToolConfig(agent.getToolConfig());
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig()));
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatModelConfig()));
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));

View File

@@ -47,6 +47,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
chatModelDO.setAdmin(user.getName());
}
save(chatModelDO);
chatModel.setId(chatModelDO.getId());
return chatModel;
}

View File

@@ -158,7 +158,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
executeReq.setQueryText(chatParseReq.getQueryText());
executeReq.setChatId(chatParseReq.getChatId());
executeReq.setUser(User.getFakeUser());
executeReq.setUser(User.getDefaultUser());
executeReq.setAgentId(chatParseReq.getAgentId());
executeReq.setSaveAnswer(true);
return execute(executeReq);

View File

@@ -29,8 +29,8 @@ public class ModelConfigHelper {
public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) {
ChatModelConfig chatModelConfig = null;
if (agent.getModelConfig().containsKey(modelType)) {
Integer chatModelId = agent.getModelConfig().get(modelType);
if (agent.getChatModelConfig().containsKey(modelType)) {
Integer chatModelId = agent.getChatModelConfig().get(modelType);
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig();
}

View File

@@ -50,7 +50,7 @@ public class QueryReqConverter {
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
}
queryNLReq.setModelConfig(chatModelConfig);
queryNLReq.setPromptConfig(agent.getPromptConfig());
queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate());
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}