From f26845e29b8c263c95426c4c2217d241a3b2d204 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sat, 12 Oct 2024 16:10:09 +0800 Subject: [PATCH] [feature][chat]Support creation of chat memory via REST API.#1603 --- .../api/pojo/request/ChatMemoryCreateReq.java | 18 ++++++++++++++++++ .../api/pojo/request/ChatMemoryUpdateReq.java | 3 +-- .../chat/server/parser/NL2SQLParser.java | 8 ++++---- .../chat/server/rest/MemoryController.java | 16 ++++++++++++++++ .../service/impl/ChatQueryServiceImpl.java | 4 ++-- .../server/service/impl/MemoryServiceImpl.java | 6 ++++-- 6 files changed, 45 insertions(+), 10 deletions(-) create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryCreateReq.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryCreateReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryCreateReq.java new file mode 100644 index 000000000..74f4f95e1 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryCreateReq.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.chat.api.pojo.request; + +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import lombok.Data; + +@Data +public class ChatMemoryCreateReq { + + private Integer agentId; + + private String question; + + private String dbSchema; + + private String s2sql; + + private MemoryStatus status; +} diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java index cc76176aa..39a7f3c28 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java @@ -4,11 +4,10 @@ import javax.validation.constraints.NotNull; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; -import com.tencent.supersonic.common.pojo.RecordInfo; import lombok.Data; @Data -public class ChatMemoryUpdateReq extends RecordInfo { +public class ChatMemoryUpdateReq { @NotNull(message = "id不可为空") private Long id; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 33e2e4006..0c83bcbd8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -78,12 +78,12 @@ public class NL2SQLParser implements ChatQueryParser { public NL2SQLParser() { ChatAppManager.register(APP_KEY_MULTI_TURN, - ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION) - .name("多轮对话改写").description("通过大模型根据历史对话来改写本轮对话").enable(false).build()); + ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写") + .description("通过大模型根据历史对话来改写本轮对话").enable(false).build()); ChatAppManager.register(APP_KEY_ERROR_MESSAGE, - ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION) - .name("异常提示改写").description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build()); + ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写") + .description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build()); } @Override diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java index 5a720dfc7..febc87bca 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java @@ -6,6 +6,8 @@ import javax.servlet.http.HttpServletResponse; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; @@ -17,6 +19,8 @@ import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import java.util.Date; + @RestController @RequestMapping({"/api/chat/memory"}) public class MemoryController { @@ -24,6 +28,18 @@ public class MemoryController { @Autowired private MemoryService memoryService; + @PostMapping("/createMemory") + public Boolean createMemory(@RequestBody ChatMemoryCreateReq chatMemoryCreateReq, + HttpServletRequest request, HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + memoryService.createMemory(ChatMemoryDO.builder().agentId(chatMemoryCreateReq.getAgentId()) + .s2sql(chatMemoryCreateReq.getS2sql()).question(chatMemoryCreateReq.getQuestion()) + .dbSchema(chatMemoryCreateReq.getDbSchema()).status(chatMemoryCreateReq.getStatus()) + .humanReviewRet(MemoryReviewResult.POSITIVE).createdBy(user.getName()) + .createdAt(new Date()).build()); + return true; + } + @PostMapping("/updateMemory") public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq, HttpServletRequest request, HttpServletResponse response) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index c824401ad..5edb1b241 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -171,8 +171,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { ParseContext parseContext = new ParseContext(); BeanMapper.mapper(chatParseReq, parseContext); Agent agent = agentService.getAgent(chatParseReq.getAgentId()); - agent.getChatAppConfig().values().forEach(c -> c.setChatModelConfig( - chatModelService.getChatModel(c.getChatModelId()).getConfig())); + agent.getChatAppConfig().values().forEach(c -> c + .setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig())); parseContext.setAgent(agent); return parseContext; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index aadd11c67..52bb070fb 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -20,6 +20,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; +import java.util.Date; import java.util.List; @Service @@ -46,10 +47,11 @@ public class MemoryServiceImpl implements MemoryService { @Override public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) { - chatMemoryUpdateReq.updatedBy(user.getName()); ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId()); - boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus()); + chatMemoryDO.setUpdatedBy(user.getName()); + chatMemoryDO.setUpdatedAt(new Date()); BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO); + boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus()); if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) { enableMemory(chatMemoryDO); } else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {