(improvement)(chat) Support configuring embeddingModel or embeddingStore at the agent level. (#1361)

This commit is contained in:
lexluo09
2024-07-06 20:44:23 +08:00
committed by GitHub
parent d39db734c4
commit 6db6aaf98d
42 changed files with 669 additions and 299 deletions

View File

@@ -4,7 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.RecordInfo;
@@ -33,7 +35,9 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private String agentConfig;
private LLMConfig llmConfig;
private ChatModelConfig llmConfig;
private ModelConfig modelConfig;
private EmbeddingStoreConfig embeddingStore;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;

View File

@@ -15,7 +15,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import dev.langchain4j.provider.ModelProvider;
import java.util.Collections;
import java.util.List;
@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatAgent.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult();

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -56,7 +56,7 @@ public class MemoryReviewTask {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(
chatAgent.getLlmConfig());
if (Objects.nonNull(chatLanguageModel)) {
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();

View File

@@ -1,14 +1,12 @@
package com.tencent.supersonic.chat.server.parser;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -26,7 +24,14 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import dev.langchain4j.provider.ModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -34,12 +39,8 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
@Slf4j
public class NL2SQLParser implements ChatParser {
@@ -180,7 +181,7 @@ public class NL2SQLParser implements ChatParser {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(context.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text();
@@ -242,7 +243,7 @@ public class NL2SQLParser implements ChatParser {
private String curtSchema;
private String histSchema;
private String histSQL;
private LLMConfig llmConfig;
private ChatModelConfig llmConfig;
}
}

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.ChatModelConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable;
@@ -15,6 +15,7 @@ import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.List;
@@ -50,7 +51,7 @@ public class AgentController {
}
@PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody LLMConfig llmConfig) {
public boolean testLLMConn(@RequestBody ChatModelConfig llmConfig) {
return LLMConnHelper.testConnection(llmConfig);
}

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.util.JsonUtil;
@@ -80,6 +80,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
/**
* the example in the agent will be executed by default,
* if the result is correct, it will be put into memory as a reference for LLM
*
* @param agent
*/
private void executeAgentExamplesAsync(Agent agent) {
@@ -121,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class));
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), ChatModelConfig.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));

View File

@@ -1,20 +1,20 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class LLMConnHelper {
public static boolean testConnection(LLMConfig llmConfig) {
public static boolean testConnection(ChatModelConfig chatModel) {
try {
if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) {
if (chatModel == null || StringUtils.isBlank(chatModel.getBaseUrl())) {
return false;
}
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig);
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatModel);
String response = chatLanguageModel.generate("Hi there");
return StringUtils.isNotEmpty(response) ? true : false;
} catch (Exception e) {