(improvement)(chat) Support embedding store configuration. (#1363)

This commit is contained in:
lexluo09
2024-07-07 00:30:19 +08:00
committed by GitHub
parent 3f460429e6
commit 4d7bfe07aa
37 changed files with 185 additions and 119 deletions

View File

@@ -4,8 +4,6 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
@@ -35,9 +33,7 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private String agentConfig;
private ChatModelConfig llmConfig;
private ModelConfig modelConfig;
private EmbeddingStoreConfig embeddingStore;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;

View File

@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatAgent.getLlmConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult();

View File

@@ -56,8 +56,8 @@ public class MemoryReviewTask {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(
chatAgent.getLlmConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
chatAgent.getModelConfig());
if (Objects.nonNull(chatLanguageModel)) {
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);

View File

@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser {
.curtSchema(curtMapStr)
.histSchema(histMapStr)
.histSQL(histSQL)
.llmConfig(queryTextReq.getLlmConfig())
.modelConfig(queryTextReq.getModelConfig())
.build());
chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
@@ -181,7 +181,7 @@ public class NL2SQLParser implements ChatParser {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(context.getLlmConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text();
@@ -243,7 +243,7 @@ public class NL2SQLParser implements ChatParser {
private String curtSchema;
private String histSchema;
private String histSQL;
private ChatModelConfig llmConfig;
private ModelConfig modelConfig;
}
}

View File

@@ -11,15 +11,18 @@ import java.util.Date;
@TableName("s2_agent")
public class AgentDO {
/**
*
*/
@TableId(type = IdType.AUTO)
private Integer id;
/**
*
*/
private String name;
/**
*
*/
private String description;
@@ -29,35 +32,40 @@ public class AgentDO {
private Integer status;
/**
*
*/
private String examples;
/**
*
*/
private String config;
/**
*
*/
private String createdBy;
/**
*
*/
private Date createdAt;
/**
*
*/
private String updatedBy;
/**
*
*/
private Date updatedAt;
/**
*
*/
private Integer enableSearch;
private String llmConfig;
private String modelConfig;
private String multiTurnConfig;
private String visualConfig;

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable;
@@ -30,16 +30,16 @@ public class AgentController {
@PostMapping
public Agent createAgent(@RequestBody Agent agent,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return agentService.createAgent(agent, user);
}
@PutMapping
public Agent updateAgent(@RequestBody Agent agent,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return agentService.updateAgent(agent, user);
}
@@ -51,8 +51,8 @@ public class AgentController {
}
@PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody ChatModelConfig llmConfig) {
return LLMConnHelper.testConnection(llmConfig);
public boolean testLLMConn(@RequestBody ModelConfig modelConfig) {
return LLMConnHelper.testConnection(modelConfig);
}
@RequestMapping("/getAgentList")

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.util.JsonUtil;
@@ -88,7 +88,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
}
private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getLlmConfig())
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|| CollectionUtils.isEmpty(agent.getExamples())) {
return;
}
@@ -122,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), ChatModelConfig.class));
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ModelConfig.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
@@ -134,7 +134,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
BeanUtils.copyProperties(agent, agentDO);
agentDO.setConfig(agent.getAgentConfig());
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig()));
agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig()));
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider;
@@ -9,12 +9,13 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j
public class LLMConnHelper {
public static boolean testConnection(ChatModelConfig chatModel) {
public static boolean testConnection(ModelConfig modelConfig) {
try {
if (chatModel == null || StringUtils.isBlank(chatModel.getBaseUrl())) {
if (modelConfig == null || modelConfig.getChatModel() == null
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
return false;
}
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatModel);
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
String response = chatLanguageModel.generate("Hi there");
return StringUtils.isNotEmpty(response) ? true : false;
} catch (Exception e) {

View File

@@ -30,7 +30,7 @@ public class QueryReqConverter {
&& MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) {
queryTextReq.setMapInfo(queryTextReq.getMapInfo());
}
queryTextReq.setLlmConfig(agent.getLlmConfig());
queryTextReq.setModelConfig(agent.getModelConfig());
queryTextReq.setPromptConfig(agent.getPromptConfig());
return queryTextReq;
}