[feature][chat]Support creation of chat memory via REST API.#1603

This commit is contained in:
jerryjzhang
2024-10-12 16:10:09 +08:00
parent f7da6b8ad1
commit f26845e29b
6 changed files with 45 additions and 10 deletions

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.Data;
@Data
public class ChatMemoryCreateReq {
private Integer agentId;
private String question;
private String dbSchema;
private String s2sql;
private MemoryStatus status;
}

View File

@@ -4,11 +4,10 @@ import javax.validation.constraints.NotNull;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data; import lombok.Data;
@Data @Data
public class ChatMemoryUpdateReq extends RecordInfo { public class ChatMemoryUpdateReq {
@NotNull(message = "id不可为空") @NotNull(message = "id不可为空")
private Long id; private Long id;

View File

@@ -78,12 +78,12 @@ public class NL2SQLParser implements ChatQueryParser {
public NL2SQLParser() { public NL2SQLParser() {
ChatAppManager.register(APP_KEY_MULTI_TURN, ChatAppManager.register(APP_KEY_MULTI_TURN,
ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION) ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写")
.name("多轮对话改写").description("通过大模型根据历史对话来改写本轮对话").enable(false).build()); .description("通过大模型根据历史对话来改写本轮对话").enable(false).build());
ChatAppManager.register(APP_KEY_ERROR_MESSAGE, ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION) ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
.name("异常提示改写").description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build()); .description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build());
} }
@Override @Override

View File

@@ -6,6 +6,8 @@ import javax.servlet.http.HttpServletResponse;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
@@ -17,6 +19,8 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import java.util.Date;
@RestController @RestController
@RequestMapping({"/api/chat/memory"}) @RequestMapping({"/api/chat/memory"})
public class MemoryController { public class MemoryController {
@@ -24,6 +28,18 @@ public class MemoryController {
@Autowired @Autowired
private MemoryService memoryService; private MemoryService memoryService;
@PostMapping("/createMemory")
public Boolean createMemory(@RequestBody ChatMemoryCreateReq chatMemoryCreateReq,
HttpServletRequest request, HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memoryService.createMemory(ChatMemoryDO.builder().agentId(chatMemoryCreateReq.getAgentId())
.s2sql(chatMemoryCreateReq.getS2sql()).question(chatMemoryCreateReq.getQuestion())
.dbSchema(chatMemoryCreateReq.getDbSchema()).status(chatMemoryCreateReq.getStatus())
.humanReviewRet(MemoryReviewResult.POSITIVE).createdBy(user.getName())
.createdAt(new Date()).build());
return true;
}
@PostMapping("/updateMemory") @PostMapping("/updateMemory")
public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq, public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq,
HttpServletRequest request, HttpServletResponse response) { HttpServletRequest request, HttpServletResponse response) {

View File

@@ -171,8 +171,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseContext parseContext = new ParseContext(); ParseContext parseContext = new ParseContext();
BeanMapper.mapper(chatParseReq, parseContext); BeanMapper.mapper(chatParseReq, parseContext);
Agent agent = agentService.getAgent(chatParseReq.getAgentId()); Agent agent = agentService.getAgent(chatParseReq.getAgentId());
agent.getChatAppConfig().values().forEach(c -> c.setChatModelConfig( agent.getChatAppConfig().values().forEach(c -> c
chatModelService.getChatModel(c.getChatModelId()).getConfig())); .setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig()));
parseContext.setAgent(agent); parseContext.setAgent(agent);
return parseContext; return parseContext;
} }

View File

@@ -20,6 +20,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.Date;
import java.util.List; import java.util.List;
@Service @Service
@@ -46,10 +47,11 @@ public class MemoryServiceImpl implements MemoryService {
@Override @Override
public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) { public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
chatMemoryUpdateReq.updatedBy(user.getName());
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId()); ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus()); chatMemoryDO.setUpdatedBy(user.getName());
chatMemoryDO.setUpdatedAt(new Date());
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO); BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) { if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
enableMemory(chatMemoryDO); enableMemory(chatMemoryDO);
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) { } else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {