diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java index 39a7f3c28..7f70721bc 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java @@ -1,12 +1,14 @@ package com.tencent.supersonic.chat.api.pojo.request; -import javax.validation.constraints.NotNull; - import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import lombok.Builder; import lombok.Data; +import javax.validation.constraints.NotNull; + @Data +@Builder public class ChatMemoryUpdateReq { @NotNull(message = "id不可为空") diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index a77d8c03f..2c3644b3a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.executor; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.pojo.ChatContext; +import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.service.ChatContextService; import com.tencent.supersonic.chat.server.service.MemoryService; @@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor { Text2SQLExemplar.class); MemoryService memoryService = ContextUtils.getBean(MemoryService.class); - memoryService.createMemory(ChatMemoryDO.builder() + memoryService.createMemory(ChatMemory.builder() .agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING) .question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo()) .dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql()) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index 09889e489..30c81460a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -1,9 +1,10 @@ package com.tencent.supersonic.chat.server.memory; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.pojo.ChatApp; @@ -66,7 +67,7 @@ public class MemoryReviewTask { } ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().agentId(agent.getId()).build(); - memoryService.getMemories(chatMemoryFilter).stream().forEach(memory -> { + memoryService.getMemories(chatMemoryFilter).forEach(memory -> { try { processMemory(memory, agent); } catch (Exception e) { @@ -77,23 +78,19 @@ public class MemoryReviewTask { } } - private void processMemory(ChatMemoryDO m, Agent agent) { + private void processMemory(ChatMemory m, Agent agent) { if (Objects.isNull(agent)) { log.warn("Agent id {} not found or memory review disabled", m.getAgentId()); return; } - ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY); - if (Objects.isNull(chatApp) || !chatApp.isEnable()) { + // if either LLM or human has reviewed, just return + if (Objects.nonNull(m.getLlmReviewRet()) || Objects.nonNull(m.getHumanReviewRet())) { return; } - // 如果大模型已经评估过,则不再评估 - if (Objects.nonNull(m.getLlmReviewRet())) { - // directly enable memory if the LLM determines it positive - if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) { - memoryService.enableMemory(m); - } + ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY); + if (Objects.isNull(chatApp) || !chatApp.isEnable()) { return; } @@ -112,19 +109,19 @@ public class MemoryReviewTask { } } - private String createPromptString(ChatMemoryDO m, String promptTemplate) { + private String createPromptString(ChatMemory m, String promptTemplate) { return String.format(promptTemplate, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), m.getS2sql()); } - private void processResponse(String response, ChatMemoryDO m) { + private void processResponse(String response, ChatMemory m) { Matcher matcher = OUTPUT_PATTERN.matcher(response); if (matcher.find()) { m.setLlmReviewRet(MemoryReviewResult.getMemoryReviewResult(matcher.group(1))); m.setLlmReviewCmt(matcher.group(2)); // directly enable memory if the LLM determines it positive if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) { - memoryService.enableMemory(m); + m.setStatus(MemoryStatus.ENABLED); } memoryService.updateMemory(m); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java index d03ca14b8..7622f53fd 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java @@ -4,17 +4,17 @@ import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; -import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; -import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; -import lombok.ToString; +import lombok.NoArgsConstructor; import java.util.Date; @Data @Builder -@ToString +@NoArgsConstructor +@AllArgsConstructor @TableName("s2_chat_memory") public class ChatMemoryDO { @TableId(type = IdType.AUTO) @@ -36,16 +36,16 @@ public class ChatMemoryDO { private String s2sql; @TableField("status") - private MemoryStatus status; + private String status; @TableField("llm_review") - private MemoryReviewResult llmReviewRet; + private String llmReviewRet; @TableField("llm_comment") private String llmReviewCmt; @TableField("human_review") - private MemoryReviewResult humanReviewRet; + private String humanReviewRet; @TableField("human_comment") private String humanReviewCmt; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatMemory.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatMemory.java new file mode 100644 index 000000000..16227be13 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatMemory.java @@ -0,0 +1,48 @@ +package com.tencent.supersonic.chat.server.pojo; + +import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.ToString; + +import java.util.Date; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +@ToString +public class ChatMemory { + private Long id; + + private Integer agentId; + + private String question; + + private String sideInfo; + + private String dbSchema; + + private String s2sql; + + private MemoryStatus status; + + private MemoryReviewResult llmReviewRet; + + private String llmReviewCmt; + + private MemoryReviewResult humanReviewRet; + + private String humanReviewCmt; + + private String createdBy; + + private Date createdAt; + + private String updatedBy; + + private Date updatedAt; +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java index 10d4bdea6..f8b13626a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java @@ -1,15 +1,12 @@ package com.tencent.supersonic.chat.server.rest; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; @@ -19,6 +16,8 @@ import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.util.Date; @RestController @@ -32,7 +31,7 @@ public class MemoryController { public Boolean createMemory(@RequestBody ChatMemoryCreateReq chatMemoryCreateReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); - memoryService.createMemory(ChatMemoryDO.builder().agentId(chatMemoryCreateReq.getAgentId()) + memoryService.createMemory(ChatMemory.builder().agentId(chatMemoryCreateReq.getAgentId()) .s2sql(chatMemoryCreateReq.getS2sql()).question(chatMemoryCreateReq.getQuestion()) .dbSchema(chatMemoryCreateReq.getDbSchema()).status(chatMemoryCreateReq.getStatus()) .humanReviewRet(MemoryReviewResult.POSITIVE).createdBy(user.getName()) @@ -49,7 +48,7 @@ public class MemoryController { } @RequestMapping("/pageMemories") - public PageInfo pageMemories(@RequestBody PageMemoryReq pageMemoryReq) { + public PageInfo pageMemories(@RequestBody PageMemoryReq pageMemoryReq) { return memoryService.pageMemories(pageMemoryReq); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java index 6189ae386..3747f7fd1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java @@ -4,27 +4,22 @@ import com.github.pagehelper.PageInfo; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.common.pojo.User; import java.util.List; public interface MemoryService { - void createMemory(ChatMemoryDO memory); + void createMemory(ChatMemory memory); void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user); - void updateMemory(ChatMemoryDO memory); - - void enableMemory(ChatMemoryDO memory); - - void disableMemory(ChatMemoryDO memory); + void updateMemory(ChatMemory memory); void batchDelete(List ids); - PageInfo pageMemories(PageMemoryReq pageMemoryReq); + PageInfo pageMemories(PageMemoryReq pageMemoryReq); - List getMemories(ChatMemoryFilter chatMemoryFilter); + List getMemories(ChatMemoryFilter chatMemoryFilter); - List getMemoriesForLlmReview(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 591d3a21b..e8e8ea012 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.VisualConfig; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; -import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper; +import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.MemoryService; @@ -121,7 +121,7 @@ public class AgentServiceImpl extends ServiceImpl implem ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build(); List memoriesExisted = memoryService.getMemories(chatMemoryFilter).stream() - .map(ChatMemoryDO::getQuestion).collect(Collectors.toList()); + .map(ChatMemory::getQuestion).collect(Collectors.toList()); for (String example : examples) { if (memoriesExisted.contains(example)) { continue; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index 4ab32b8dc..7369d50dd 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -3,12 +3,14 @@ package com.tencent.supersonic.chat.server.service.impl; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageInfo; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; +import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; @@ -16,12 +18,15 @@ import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.BeanMapper; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import java.util.Date; import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; @Service public class MemoryServiceImpl implements MemoryService { @@ -36,20 +41,21 @@ public class MemoryServiceImpl implements MemoryService { private EmbeddingConfig embeddingConfig; @Override - public void createMemory(ChatMemoryDO memory) { + public void createMemory(ChatMemory memory) { // if an existing enabled memory has the same question, just skip - List memories = + List memories = getMemories(ChatMemoryFilter.builder().agentId(memory.getAgentId()) .question(memory.getQuestion()).status(MemoryStatus.ENABLED).build()); - if (memories.size() == 0) { - chatMemoryRepository.createMemory(memory); + if (memories.isEmpty()) { + ChatMemoryDO memoryDO = getMemoryDO(memory); + chatMemoryRepository.createMemory(memoryDO); } } @Override public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) { ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId()); - boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus()); + boolean hadEnabled = MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim()); chatMemoryDO.setUpdatedBy(user.getName()); chatMemoryDO.setUpdatedAt(new Date()); BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO); @@ -58,12 +64,12 @@ public class MemoryServiceImpl implements MemoryService { } else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) { disableMemory(chatMemoryDO); } - updateMemory(chatMemoryDO); + chatMemoryRepository.updateMemory(chatMemoryDO); } @Override - public void updateMemory(ChatMemoryDO memory) { - chatMemoryRepository.updateMemory(memory); + public void updateMemory(ChatMemory memory) { + chatMemoryRepository.updateMemory(getMemoryDO(memory)); } @Override @@ -72,7 +78,7 @@ public class MemoryServiceImpl implements MemoryService { } @Override - public PageInfo pageMemories(PageMemoryReq pageMemoryReq) { + public PageInfo pageMemories(PageMemoryReq pageMemoryReq) { ChatMemoryFilter chatMemoryFilter = pageMemoryReq.getChatMemoryFilter(); chatMemoryFilter.setSort(pageMemoryReq.getSort()); chatMemoryFilter.setOrderCondition(pageMemoryReq.getOrderCondition()); @@ -81,7 +87,7 @@ public class MemoryServiceImpl implements MemoryService { } @Override - public List getMemories(ChatMemoryFilter chatMemoryFilter) { + public List getMemories(ChatMemoryFilter chatMemoryFilter) { QueryWrapper queryWrapper = new QueryWrapper<>(); if (chatMemoryFilter.getAgentId() != null) { queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId()); @@ -109,32 +115,51 @@ public class MemoryServiceImpl implements MemoryService { queryWrapper.orderBy(true, chatMemoryFilter.isAsc(), chatMemoryFilter.getOrderCondition()); } - return chatMemoryRepository.getMemories(queryWrapper); + List chatMemoryDOS = chatMemoryRepository.getMemories(queryWrapper); + return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList()); } - @Override - public List getMemoriesForLlmReview() { - QueryWrapper queryWrapper = new QueryWrapper<>(); - queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING) - .isNull(ChatMemoryDO::getLlmReviewRet); - return chatMemoryRepository.getMemories(queryWrapper); - } - - @Override - public void enableMemory(ChatMemoryDO memory) { - memory.setStatus(MemoryStatus.ENABLED); + private void enableMemory(ChatMemoryDO memory) { + memory.setStatus(MemoryStatus.ENABLED.toString()); exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), Text2SQLExemplar.builder().question(memory.getQuestion()) .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema()) .sql(memory.getS2sql()).build()); } - @Override - public void disableMemory(ChatMemoryDO memory) { - memory.setStatus(MemoryStatus.DISABLED); + private void disableMemory(ChatMemoryDO memory) { + memory.setStatus(MemoryStatus.DISABLED.toString()); exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), Text2SQLExemplar.builder().question(memory.getQuestion()) .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema()) .sql(memory.getS2sql()).build()); } + + private ChatMemoryDO getMemoryDO(ChatMemory memory) { + ChatMemoryDO memoryDO = new ChatMemoryDO(); + BeanUtils.copyProperties(memory, memoryDO); + memoryDO.setStatus(memory.getStatus().toString().trim()); + if (Objects.nonNull(memory.getHumanReviewRet())) { + memoryDO.setHumanReviewRet(memory.getHumanReviewRet().toString().trim()); + } + if (Objects.nonNull(memory.getLlmReviewRet())) { + memoryDO.setLlmReviewRet(memory.getLlmReviewRet().toString().trim()); + } + + return memoryDO; + } + + private ChatMemory getMemory(ChatMemoryDO memoryDO) { + ChatMemory memory = new ChatMemory(); + BeanUtils.copyProperties(memoryDO, memory); + memory.setStatus(MemoryStatus.valueOf(memoryDO.getStatus().trim())); + if (Objects.nonNull(memoryDO.getHumanReviewRet())) { + memory.setHumanReviewRet(MemoryReviewResult.valueOf(memoryDO.getHumanReviewRet().trim())); + } + if (Objects.nonNull(memoryDO.getLlmReviewRet())) { + memory.setLlmReviewRet(MemoryReviewResult.valueOf(memoryDO.getLlmReviewRet().trim())); + } + return memory; + } + }