From 89c63cd44d9532fee958243688e80233f458a5c9 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Thu, 27 Jun 2024 17:12:35 +0800 Subject: [PATCH] (improvement)(Chat) Put agent examples into ChatMemory automatically (#1249) Co-authored-by: lxwcodemonkey --- .../api/pojo/enums/MemoryReviewResult.java | 9 +++ .../chat/api/pojo/enums/MemoryStatus.java | 10 +++ .../api/pojo/request/ChatMemoryFilter.java | 28 ++++++++ .../chat/api/pojo/request/PageMemoryReq.java | 12 ++++ .../chat/server/executor/SqlExecutor.java | 3 +- .../AgentExample2MemoryTransformer.java | 49 ++++++++++++++ .../chat/server/memory/MemoryReviewTask.java | 8 +-- .../persistence/dataobject/ChatMemoryDO.java | 18 ++---- .../repository/ChatMemoryRepository.java | 3 +- .../impl/ChatMemoryRepositoryImpl.java | 4 +- .../chat/server/rest/MemoryController.java | 54 ++++++++++++++++ .../chat/server/service/ChatService.java | 2 + .../chat/server/service/MemoryService.java | 9 ++- .../server/service/impl/AgentServiceImpl.java | 9 ++- .../server/service/impl/ChatServiceImpl.java | 13 ++++ .../service/impl/MemoryServiceImpl.java | 64 ++++++++++++++++--- .../common/service/ExemplarService.java | 2 + .../service/impl/ExemplarServiceImpl.java | 8 +++ .../resources/config.update/sql-update.sql | 19 ++++++ .../src/main/resources/db/schema-mysql.sql | 4 +- 20 files changed, 292 insertions(+), 36 deletions(-) create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/PageMemoryReq.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java new file mode 100644 index 000000000..df706ab76 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.chat.api.pojo.enums; + + +public enum MemoryReviewResult { + + POSITIVE, + NEGATIVE + +} diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java new file mode 100644 index 000000000..b00f08b96 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.chat.api.pojo.enums; + + +public enum MemoryStatus { + + PENDING, + ENABLED, + DISABLED; + +} diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java new file mode 100644 index 000000000..00ec855b4 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java @@ -0,0 +1,28 @@ +package com.tencent.supersonic.chat.api.pojo.request; + +import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class ChatMemoryFilter { + + private String question; + + private List questions; + + private MemoryStatus status; + + private MemoryReviewResult llmReviewRet; + + private MemoryReviewResult humanReviewRet; + +} diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/PageMemoryReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/PageMemoryReq.java new file mode 100644 index 000000000..fde9aa2d7 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/PageMemoryReq.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.chat.api.pojo.request; + +import com.tencent.supersonic.common.pojo.PageBaseReq; +import lombok.Data; + + +@Data +public class PageMemoryReq extends PageBaseReq { + + private ChatMemoryFilter chatMemoryFilter = new ChatMemoryFilter(); + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index bb53830ca..a9c9fb4e6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.server.executor; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; import com.tencent.supersonic.chat.server.service.MemoryService; @@ -32,7 +33,7 @@ public class SqlExecutor implements ChatExecutor { MemoryService memoryService = ContextUtils.getBean(MemoryService.class); memoryService.createMemory(ChatMemoryDO.builder() .agentId(chatExecuteContext.getAgentId()) - .status(ChatMemoryDO.Status.PENDING) + .status(MemoryStatus.PENDING) .question(chatExecuteContext.getQueryText()) .s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL()) .dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo())) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java new file mode 100644 index 000000000..7693922b9 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/AgentExample2MemoryTransformer.java @@ -0,0 +1,49 @@ +package com.tencent.supersonic.chat.server.memory; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.service.ChatService; +import com.tencent.supersonic.chat.server.service.MemoryService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Component; +import java.util.List; +import java.util.stream.Collectors; + +@Component +@Slf4j +public class AgentExample2MemoryTransformer { + + @Autowired + private ChatService chatService; + + @Autowired + private MemoryService memoryService; + + @Async + public void transform(Agent agent) { + if (!agent.containsLLMParserTool() || agent.getLlmConfig() == null) { + return; + } + List examples = agent.getExamples(); + ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().questions(examples).build(); + List memoriesExisted = memoryService.getMemories(chatMemoryFilter) + .stream().map(ChatMemoryDO::getQuestion).collect(Collectors.toList()); + for (String example : examples) { + if (memoriesExisted.contains(example)) { + continue; + } + ChatParseReq chatParseReq = new ChatParseReq(); + chatParseReq.setAgentId(agent.getId()); + chatParseReq.setQueryText(example); + chatParseReq.setUser(User.getFakeUser()); + chatParseReq.setChatId(-1); + chatService.parseAndExecute(chatParseReq); + } + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index 6f3c92c18..94f382cf0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -1,8 +1,7 @@ package com.tencent.supersonic.chat.server.memory; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.util.S2ChatModelProvider; @@ -48,8 +47,7 @@ public class MemoryReviewTask { @Scheduled(fixedDelay = 60 * 1000) public void review() { - memoryService.getMemories().stream() - .filter(c -> c.getStatus() == ChatMemoryDO.Status.PENDING) + memoryService.getMemoriesForLlmReview().stream() .forEach(m -> { Agent chatAgent = agentService.getAgent(m.getAgentId()); String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql()); @@ -62,7 +60,7 @@ public class MemoryReviewTask { Matcher matcher = OUTPUT_PATTERN.matcher(response); if (matcher.find()) { - m.setLlmReviewRet(ReviewResult.valueOf(matcher.group(1))); + m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1))); m.setLlmReviewCmt(matcher.group(2)); memoryService.updateMemory(m); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java index 9276b9ae8..fea7ee38f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java @@ -4,6 +4,8 @@ import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import lombok.Builder; import lombok.Data; import lombok.ToString; @@ -31,16 +33,16 @@ public class ChatMemoryDO { private String s2sql; @TableField("status") - private Status status; + private MemoryStatus status; @TableField("llm_review") - private ReviewResult llmReviewRet; + private MemoryReviewResult llmReviewRet; @TableField("llm_comment") private String llmReviewCmt; @TableField("human_review") - private ReviewResult humanReviewRet; + private MemoryReviewResult humanReviewRet; @TableField("human_comment") private String humanReviewCmt; @@ -57,14 +59,4 @@ public class ChatMemoryDO { @TableField("updated_at") private Date updatedAt; - public enum ReviewResult { - POSITIVE, - NEGATIVE - } - - public enum Status { - PENDING, - ENABLED, - DISABLED; - } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java index 053699a52..90d9b2a11 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.server.persistence.repository; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import java.util.List; @@ -11,5 +12,5 @@ public interface ChatMemoryRepository { ChatMemoryDO getMemory(Long id); - List getMemories(); + List getMemories(QueryWrapper queryWrapper); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java index 4e2033b0a..6d71bcd5d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java @@ -35,8 +35,8 @@ public class ChatMemoryRepositoryImpl implements ChatMemoryRepository { } @Override - public List getMemories() { - return chatMemoryMapper.selectList(new QueryWrapper<>()); + public List getMemories(QueryWrapper queryWrapper) { + return chatMemoryMapper.selectList(queryWrapper); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java new file mode 100644 index 000000000..55f20a310 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java @@ -0,0 +1,54 @@ +package com.tencent.supersonic.chat.server.rest; + +import com.github.pagehelper.PageInfo; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.service.MemoryService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Date; + +@RestController +@RequestMapping({"/api/chat/memory"}) +public class MemoryController { + + @Autowired + private MemoryService memoryService; + + @PostMapping("/createMemory") + public Boolean createMemory(@RequestBody ChatMemoryDO memory, + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + memory.setCreatedBy(user.getName()); + memory.setUpdatedBy(user.getName()); + memory.setCreatedAt(new Date()); + memory.setUpdatedAt(new Date()); + memoryService.createMemory(memory); + return true; + } + + @PostMapping("/updateMemory") + public void updateMemory(ChatMemoryDO memory, + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + memory.setUpdatedBy(user.getName()); + memory.setUpdatedAt(new Date()); + memoryService.updateMemory(memory); + } + + @RequestMapping("/pageMemories") + public PageInfo pageMemories(@RequestBody PageMemoryReq pageMemoryReq) { + return memoryService.pageMemories(pageMemoryReq); + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java index 4001e9b94..37af953dc 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java @@ -20,6 +20,8 @@ public interface ChatService { QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception; + QueryResult parseAndExecute(ChatParseReq chatParseReq); + Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception; SemanticParseInfo queryContext(Integer chatId); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java index 4b1965145..3e445b01c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java @@ -1,5 +1,8 @@ package com.tencent.supersonic.chat.server.service; +import com.github.pagehelper.PageInfo; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; +import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import java.util.List; @@ -9,5 +12,9 @@ public interface MemoryService { void updateMemory(ChatMemoryDO memory); - List getMemories(); + PageInfo pageMemories(PageMemoryReq pageMemoryReq); + + List getMemories(ChatMemoryFilter chatMemoryFilter); + + List getMemoriesForLlmReview(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 9fa5cd066..6dba0686f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -4,13 +4,15 @@ import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; +import com.tencent.supersonic.chat.server.memory.AgentExample2MemoryTransformer; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper; import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.config.LLMConfig; import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.List; import java.util.stream.Collectors; @@ -19,6 +21,9 @@ import java.util.stream.Collectors; public class AgentServiceImpl extends ServiceImpl implements AgentService { + @Autowired + private AgentExample2MemoryTransformer agentExample2MemoryTransformer; + @Override public List getAgents() { return getAgentDOList().stream() @@ -30,6 +35,7 @@ public class AgentServiceImpl extends ServiceImpl agent.createdBy(user.getName()); AgentDO agentDO = convert(agent); save(agentDO); + agentExample2MemoryTransformer.transform(agent); return agentDO.getId(); } @@ -37,6 +43,7 @@ public class AgentServiceImpl extends ServiceImpl public void updateAgent(Agent agent, User user) { agent.updatedBy(user.getName()); updateById(convert(agent)); + agentExample2MemoryTransformer.transform(agent); } @Override diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index 3d194bad8..cc2a89ea6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -95,6 +95,19 @@ public class ChatServiceImpl implements ChatService { return queryResult; } + @Override + public QueryResult parseAndExecute(ChatParseReq chatParseReq) { + ParseResp parseResp = performParsing(chatParseReq); + ChatExecuteReq chatExecuteReq = new ChatExecuteReq(); + chatExecuteReq.setQueryId(parseResp.getQueryId()); + chatExecuteReq.setChatId(chatParseReq.getChatId()); + chatExecuteReq.setUser(chatParseReq.getUser()); + chatExecuteReq.setAgentId(chatParseReq.getAgentId()); + chatExecuteReq.setQueryText(chatParseReq.getQueryText()); + chatExecuteReq.setParseId(parseResp.getSelectedParses().get(0).getId()); + return performExecution(chatExecuteReq); + } + private ChatParseContext buildParseContext(ChatParseReq chatParseReq) { ChatParseContext chatParseContext = new ChatParseContext(); BeanMapper.mapper(chatParseReq, chatParseContext); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index 5032f9dea..477fe3018 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -1,14 +1,20 @@ package com.tencent.supersonic.chat.server.service.impl; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.github.pagehelper.PageHelper; +import com.github.pagehelper.PageInfo; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; +import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.Status; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.service.ExemplarService; +import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; import java.util.List; @@ -23,24 +29,53 @@ public class MemoryServiceImpl implements MemoryService { @Override public void createMemory(ChatMemoryDO memory) { - if (ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) { - enableMemory(memory); - } chatMemoryRepository.createMemory(memory); } @Override public void updateMemory(ChatMemoryDO memory) { - if (!ChatMemoryDO.Status.ENABLED.equals(memory.getStatus()) - && ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) { + if (MemoryStatus.ENABLED.equals(memory.getStatus())) { enableMemory(memory); + } else if (MemoryStatus.DISABLED.equals(memory.getStatus())) { + disableMemory(memory); } chatMemoryRepository.updateMemory(memory); } @Override - public List getMemories() { - return chatMemoryRepository.getMemories(); + public PageInfo pageMemories(PageMemoryReq pageMemoryReq) { + return PageHelper.startPage(pageMemoryReq.getCurrent(), + pageMemoryReq.getPageSize()) + .doSelectPageInfo(() -> getMemories(pageMemoryReq.getChatMemoryFilter())); + } + + @Override + public List getMemories(ChatMemoryFilter chatMemoryFilter) { + QueryWrapper queryWrapper = new QueryWrapper<>(); + if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) { + queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion()); + } + if (!CollectionUtils.isEmpty(chatMemoryFilter.getQuestions())) { + queryWrapper.lambda().in(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestions()); + } + if (chatMemoryFilter.getStatus() != null) { + queryWrapper.lambda().eq(ChatMemoryDO::getStatus, chatMemoryFilter.getStatus()); + } + if (chatMemoryFilter.getHumanReviewRet() != null) { + queryWrapper.lambda().eq(ChatMemoryDO::getHumanReviewRet, chatMemoryFilter.getHumanReviewRet()); + } + if (chatMemoryFilter.getLlmReviewRet() != null) { + queryWrapper.lambda().eq(ChatMemoryDO::getLlmReviewRet, chatMemoryFilter.getLlmReviewRet()); + } + return chatMemoryRepository.getMemories(queryWrapper); + } + + @Override + public List getMemoriesForLlmReview() { + QueryWrapper queryWrapper = new QueryWrapper<>(); + queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING) + .isNull(ChatMemoryDO::getLlmReviewRet); + return chatMemoryRepository.getMemories(queryWrapper); } private void enableMemory(ChatMemoryDO memory) { @@ -50,6 +85,15 @@ public class MemoryServiceImpl implements MemoryService { .dbSchema(memory.getDbSchema()) .sql(memory.getS2sql()) .build()); - memory.setStatus(Status.ENABLED); } + + private void disableMemory(ChatMemoryDO memory) { + exemplarService.removeExemplar(memory.getAgentId().toString(), + SqlExemplar.builder() + .question(memory.getQuestion()) + .dbSchema(memory.getDbSchema()) + .sql(memory.getS2sql()) + .build()); + } + } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java b/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java index f08839682..e7ae8554b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java @@ -7,6 +7,8 @@ import java.util.List; public interface ExemplarService { void storeExemplar(String collection, SqlExemplar exemplar); + void removeExemplar(String collection, SqlExemplar exemplar); + List recallExemplars(String collection, String query, int num); List recallExemplars(String query, int num); diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java index 635666b43..9d7d23833 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java @@ -49,6 +49,14 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { embeddingService.addQuery(collection, Lists.newArrayList(segment)); } + public void removeExemplar(String collection, SqlExemplar exemplar) { + Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar), + String.class, Object.class)); + TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata); + + embeddingService.deleteQuery(collection, Lists.newArrayList(segment)); + } + public List recallExemplars(String query, int num) { String collection = embeddingConfig.getText2sqlCollectionName(); return recallExemplars(collection, query, num); diff --git a/launchers/standalone/src/main/resources/config.update/sql-update.sql b/launchers/standalone/src/main/resources/config.update/sql-update.sql index ed8e2fd8e..183caa7d3 100644 --- a/launchers/standalone/src/main/resources/config.update/sql-update.sql +++ b/launchers/standalone/src/main/resources/config.update/sql-update.sql @@ -329,3 +329,22 @@ alter table s2_agent add column `visual_config` varchar(2000) COLLATE utf8_unic alter table s2_term add column `related_metrics` varchar(1000) DEFAULT NULL COMMENT '术语关联的指标'; alter table s2_term add column `related_dimensions` varchar(1000) DEFAULT NULL COMMENT '术语关联的维度'; + +--20240627 +CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( + `id` INT NOT NULL AUTO_INCREMENT, + `question` varchar(655) , + `agent_id` INT , + `db_schema` TEXT , + `s2_sql` TEXT , + `status` char(10) , + `llm_review` char(10) , + `llm_comment` TEXT, + `human_review` char(10) , + `human_comment` TEXT , + `created_at` datetime NOT NULL , + `updated_at` datetime NOT NULL , + `created_by` varchar(100) NOT NULL , + `updated_by` varchar(100) NOT NULL , + PRIMARY KEY (`id`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8; \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index f93479420..c04ed2f52 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -147,8 +147,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( `llm_comment` TEXT, `human_review` char(10) , `human_comment` TEXT , - `created_at` TIMESTAMP NOT NULL , - `updated_at` TIMESTAMP NOT NULL , + `created_at` datetime NOT NULL , + `updated_at` datetime NOT NULL , `created_by` varchar(100) NOT NULL , `updated_by` varchar(100) NOT NULL , PRIMARY KEY (`id`)