From 560c26fbf33efdc919f70ea96ce6ce872757c71c Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Thu, 27 Jun 2024 22:33:24 +0800 Subject: [PATCH] (improvement)(Chat) Refactor agent examples execution code (#1258) Co-authored-by: lxwcodemonkey --- .../AgentExample2MemoryTransformer.java | 49 ----------------- .../server/service/impl/AgentServiceImpl.java | 53 +++++++++++++++++-- .../src/main/resources/db/schema-mysql.sql | 20 +++---- 3 files changed, 59 insertions(+), 63 deletions(-) delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java deleted file mode 100644 index 7693922b9..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java +++ /dev/null @@ -1,49 +0,0 @@ -package com.tencent.supersonic.chat.server.memory; - -import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; -import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; -import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; -import com.tencent.supersonic.chat.server.service.ChatService; -import com.tencent.supersonic.chat.server.service.MemoryService; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.scheduling.annotation.Async; -import org.springframework.stereotype.Component; -import java.util.List; -import java.util.stream.Collectors; - -@Component -@Slf4j -public class AgentExample2MemoryTransformer { - - @Autowired - private ChatService chatService; - - @Autowired - private MemoryService memoryService; - - @Async - public void transform(Agent agent) { - if (!agent.containsLLMParserTool() || agent.getLlmConfig() == null) { - return; - } - List examples = agent.getExamples(); - ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().questions(examples).build(); - List memoriesExisted = memoryService.getMemories(chatMemoryFilter) - .stream().map(ChatMemoryDO::getQuestion).collect(Collectors.toList()); - for (String example : examples) { - if (memoriesExisted.contains(example)) { - continue; - } - ChatParseReq chatParseReq = new ChatParseReq(); - chatParseReq.setAgentId(agent.getId()); - chatParseReq.setQueryText(example); - chatParseReq.setUser(User.getFakeUser()); - chatParseReq.setChatId(-1); - chatService.parseAndExecute(chatParseReq); - } - } - -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 6dba0686f..f7676d663 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -2,27 +2,42 @@ package com.tencent.supersonic.chat.server.service.impl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; -import com.tencent.supersonic.chat.server.memory.AgentExample2MemoryTransformer; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper; import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.chat.server.service.ChatService; +import com.tencent.supersonic.chat.server.service.MemoryService; +import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.util.JsonUtil; +import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; + import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.stream.Collectors; +@Slf4j @Service public class AgentServiceImpl extends ServiceImpl implements AgentService { @Autowired - private AgentExample2MemoryTransformer agentExample2MemoryTransformer; + private MemoryService memoryService; + + @Autowired + private ChatService chatService; + + private ExecutorService executorService = Executors.newFixedThreadPool(1); @Override public List getAgents() { @@ -35,7 +50,7 @@ public class AgentServiceImpl extends ServiceImpl agent.createdBy(user.getName()); AgentDO agentDO = convert(agent); save(agentDO); - agentExample2MemoryTransformer.transform(agent); + executeAgentExamplesAsync(agent); return agentDO.getId(); } @@ -43,7 +58,7 @@ public class AgentServiceImpl extends ServiceImpl public void updateAgent(Agent agent, User user) { agent.updatedBy(user.getName()); updateById(convert(agent)); - agentExample2MemoryTransformer.transform(agent); + executeAgentExamplesAsync(agent); } @Override @@ -59,6 +74,36 @@ public class AgentServiceImpl extends ServiceImpl removeById(id); } + /** + * the example in the agent will be executed by default, + * if the result is correct, it will be put into memory as a reference for LLM + * @param agent + */ + private void executeAgentExamplesAsync(Agent agent) { + executorService.execute(() -> doExecuteAgentExamples(agent)); + } + + private synchronized void doExecuteAgentExamples(Agent agent) { + if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getLlmConfig())) { + return; + } + List examples = agent.getExamples(); + ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().questions(examples).build(); + List memoriesExisted = memoryService.getMemories(chatMemoryFilter) + .stream().map(ChatMemoryDO::getQuestion).collect(Collectors.toList()); + for (String example : examples) { + if (memoriesExisted.contains(example)) { + continue; + } + ChatParseReq chatParseReq = new ChatParseReq(); + chatParseReq.setAgentId(agent.getId()); + chatParseReq.setQueryText(example); + chatParseReq.setUser(User.getFakeUser()); + chatParseReq.setChatId(-1); + chatService.parseAndExecute(chatParseReq); + } + } + private List getAgentDOList() { return list(); } diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index c04ed2f52..c90e916f6 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -138,16 +138,16 @@ CREATE TABLE `s2_chat_config` ( CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( `id` INT NOT NULL AUTO_INCREMENT, - `question` varchar(655) , - `agent_id` INT , - `db_schema` TEXT , - `s2_sql` TEXT , - `status` char(10) , - `llm_review` char(10) , - `llm_comment` TEXT, - `human_review` char(10) , - `human_comment` TEXT , - `created_at` datetime NOT NULL , + `question` varchar(655) COMMENT '用户问题' , + `agent_id` INT COMMENT '助理ID' , + `db_schema` TEXT COMMENT 'Schema映射' , + `s2_sql` TEXT COMMENT '大模型解析SQL' , + `status` char(10) COMMENT '状态' , + `llm_review` char(10) COMMENT '大模型评估结果' , + `llm_comment` TEXT COMMENT '大模型评估意见' , + `human_review` char(10) COMMENT '管理员评估结果', + `human_comment` TEXT COMMENT '管理员评估意见', + `created_at` datetime NOT NULL , `updated_at` datetime NOT NULL , `created_by` varchar(100) NOT NULL , `updated_by` varchar(100) NOT NULL ,