[improvement][chat]Introduce ChatMemory to delegate ChatMemoryDO.

This commit is contained in:
jerryjzhang
2024-12-02 01:31:37 +08:00
parent ba4d92e11c
commit 7de3662259
9 changed files with 134 additions and 68 deletions

View File

@@ -1,12 +1,14 @@
package com.tencent.supersonic.chat.api.pojo.request;
import javax.validation.constraints.NotNull;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.Builder;
import lombok.Data;
import javax.validation.constraints.NotNull;
@Data
@Builder
public class ChatMemoryUpdateReq {
@NotNull(message = "id不可为空")

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor {
Text2SQLExemplar.class);
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemoryDO.builder()
memoryService.createMemory(ChatMemory.builder()
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())

View File

@@ -1,9 +1,10 @@
package com.tencent.supersonic.chat.server.memory;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.ChatApp;
@@ -66,7 +67,7 @@ public class MemoryReviewTask {
}
ChatMemoryFilter chatMemoryFilter =
ChatMemoryFilter.builder().agentId(agent.getId()).build();
memoryService.getMemories(chatMemoryFilter).stream().forEach(memory -> {
memoryService.getMemories(chatMemoryFilter).forEach(memory -> {
try {
processMemory(memory, agent);
} catch (Exception e) {
@@ -77,23 +78,19 @@ public class MemoryReviewTask {
}
}
private void processMemory(ChatMemoryDO m, Agent agent) {
private void processMemory(ChatMemory m, Agent agent) {
if (Objects.isNull(agent)) {
log.warn("Agent id {} not found or memory review disabled", m.getAgentId());
return;
}
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
// if either LLM or human has reviewed, just return
if (Objects.nonNull(m.getLlmReviewRet()) || Objects.nonNull(m.getHumanReviewRet())) {
return;
}
// 如果大模型已经评估过,则不再评估
if (Objects.nonNull(m.getLlmReviewRet())) {
// directly enable memory if the LLM determines it positive
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
memoryService.enableMemory(m);
}
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
}
@@ -112,19 +109,19 @@ public class MemoryReviewTask {
}
}
private String createPromptString(ChatMemoryDO m, String promptTemplate) {
private String createPromptString(ChatMemory m, String promptTemplate) {
return String.format(promptTemplate, m.getQuestion(), m.getDbSchema(), m.getSideInfo(),
m.getS2sql());
}
private void processResponse(String response, ChatMemoryDO m) {
private void processResponse(String response, ChatMemory m) {
Matcher matcher = OUTPUT_PATTERN.matcher(response);
if (matcher.find()) {
m.setLlmReviewRet(MemoryReviewResult.getMemoryReviewResult(matcher.group(1)));
m.setLlmReviewCmt(matcher.group(2));
// directly enable memory if the LLM determines it positive
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
memoryService.enableMemory(m);
m.setStatus(MemoryStatus.ENABLED);
}
memoryService.updateMemory(m);
}

View File

@@ -4,17 +4,17 @@ import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
import lombok.NoArgsConstructor;
import java.util.Date;
@Data
@Builder
@ToString
@NoArgsConstructor
@AllArgsConstructor
@TableName("s2_chat_memory")
public class ChatMemoryDO {
@TableId(type = IdType.AUTO)
@@ -36,16 +36,16 @@ public class ChatMemoryDO {
private String s2sql;
@TableField("status")
private MemoryStatus status;
private String status;
@TableField("llm_review")
private MemoryReviewResult llmReviewRet;
private String llmReviewRet;
@TableField("llm_comment")
private String llmReviewCmt;
@TableField("human_review")
private MemoryReviewResult humanReviewRet;
private String humanReviewRet;
@TableField("human_comment")
private String humanReviewCmt;

View File

@@ -0,0 +1,48 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.util.Date;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
@ToString
public class ChatMemory {
private Long id;
private Integer agentId;
private String question;
private String sideInfo;
private String dbSchema;
private String s2sql;
private MemoryStatus status;
private MemoryReviewResult llmReviewRet;
private String llmReviewCmt;
private MemoryReviewResult humanReviewRet;
private String humanReviewCmt;
private String createdBy;
private Date createdAt;
private String updatedBy;
private Date updatedAt;
}

View File

@@ -1,15 +1,12 @@
package com.tencent.supersonic.chat.server.rest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
@@ -19,6 +16,8 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Date;
@RestController
@@ -32,7 +31,7 @@ public class MemoryController {
public Boolean createMemory(@RequestBody ChatMemoryCreateReq chatMemoryCreateReq,
HttpServletRequest request, HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memoryService.createMemory(ChatMemoryDO.builder().agentId(chatMemoryCreateReq.getAgentId())
memoryService.createMemory(ChatMemory.builder().agentId(chatMemoryCreateReq.getAgentId())
.s2sql(chatMemoryCreateReq.getS2sql()).question(chatMemoryCreateReq.getQuestion())
.dbSchema(chatMemoryCreateReq.getDbSchema()).status(chatMemoryCreateReq.getStatus())
.humanReviewRet(MemoryReviewResult.POSITIVE).createdBy(user.getName())
@@ -49,7 +48,7 @@ public class MemoryController {
}
@RequestMapping("/pageMemories")
public PageInfo<ChatMemoryDO> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
public PageInfo<ChatMemory> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
return memoryService.pageMemories(pageMemoryReq);
}

View File

@@ -4,27 +4,22 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.common.pojo.User;
import java.util.List;
public interface MemoryService {
void createMemory(ChatMemoryDO memory);
void createMemory(ChatMemory memory);
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
void updateMemory(ChatMemoryDO memory);
void enableMemory(ChatMemoryDO memory);
void disableMemory(ChatMemoryDO memory);
void updateMemory(ChatMemory memory);
void batchDelete(List<Long> ids);
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq);
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
List<ChatMemory> getMemories(ChatMemoryFilter chatMemoryFilter);
List<ChatMemoryDO> getMemoriesForLlmReview();
}

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -121,7 +121,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
ChatMemoryFilter chatMemoryFilter =
ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build();
List<String> memoriesExisted = memoryService.getMemories(chatMemoryFilter).stream()
.map(ChatMemoryDO::getQuestion).collect(Collectors.toList());
.map(ChatMemory::getQuestion).collect(Collectors.toList());
for (String example : examples) {
if (memoriesExisted.contains(example)) {
continue;

View File

@@ -3,12 +3,14 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
@@ -16,12 +18,15 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Service
public class MemoryServiceImpl implements MemoryService {
@@ -36,20 +41,21 @@ public class MemoryServiceImpl implements MemoryService {
private EmbeddingConfig embeddingConfig;
@Override
public void createMemory(ChatMemoryDO memory) {
public void createMemory(ChatMemory memory) {
// if an existing enabled memory has the same question, just skip
List<ChatMemoryDO> memories =
List<ChatMemory> memories =
getMemories(ChatMemoryFilter.builder().agentId(memory.getAgentId())
.question(memory.getQuestion()).status(MemoryStatus.ENABLED).build());
if (memories.size() == 0) {
chatMemoryRepository.createMemory(memory);
if (memories.isEmpty()) {
ChatMemoryDO memoryDO = getMemoryDO(memory);
chatMemoryRepository.createMemory(memoryDO);
}
}
@Override
public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
boolean hadEnabled = MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
chatMemoryDO.setUpdatedBy(user.getName());
chatMemoryDO.setUpdatedAt(new Date());
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
@@ -58,12 +64,12 @@ public class MemoryServiceImpl implements MemoryService {
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
disableMemory(chatMemoryDO);
}
updateMemory(chatMemoryDO);
chatMemoryRepository.updateMemory(chatMemoryDO);
}
@Override
public void updateMemory(ChatMemoryDO memory) {
chatMemoryRepository.updateMemory(memory);
public void updateMemory(ChatMemory memory) {
chatMemoryRepository.updateMemory(getMemoryDO(memory));
}
@Override
@@ -72,7 +78,7 @@ public class MemoryServiceImpl implements MemoryService {
}
@Override
public PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq) {
public PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq) {
ChatMemoryFilter chatMemoryFilter = pageMemoryReq.getChatMemoryFilter();
chatMemoryFilter.setSort(pageMemoryReq.getSort());
chatMemoryFilter.setOrderCondition(pageMemoryReq.getOrderCondition());
@@ -81,7 +87,7 @@ public class MemoryServiceImpl implements MemoryService {
}
@Override
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
public List<ChatMemory> getMemories(ChatMemoryFilter chatMemoryFilter) {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
if (chatMemoryFilter.getAgentId() != null) {
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId());
@@ -109,32 +115,51 @@ public class MemoryServiceImpl implements MemoryService {
queryWrapper.orderBy(true, chatMemoryFilter.isAsc(),
chatMemoryFilter.getOrderCondition());
}
return chatMemoryRepository.getMemories(queryWrapper);
List<ChatMemoryDO> chatMemoryDOS = chatMemoryRepository.getMemories(queryWrapper);
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
}
@Override
public List<ChatMemoryDO> getMemoriesForLlmReview() {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
.isNull(ChatMemoryDO::getLlmReviewRet);
return chatMemoryRepository.getMemories(queryWrapper);
}
@Override
public void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED);
private void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED.toString());
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion())
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
.sql(memory.getS2sql()).build());
}
@Override
public void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED);
private void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED.toString());
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion())
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
.sql(memory.getS2sql()).build());
}
private ChatMemoryDO getMemoryDO(ChatMemory memory) {
ChatMemoryDO memoryDO = new ChatMemoryDO();
BeanUtils.copyProperties(memory, memoryDO);
memoryDO.setStatus(memory.getStatus().toString().trim());
if (Objects.nonNull(memory.getHumanReviewRet())) {
memoryDO.setHumanReviewRet(memory.getHumanReviewRet().toString().trim());
}
if (Objects.nonNull(memory.getLlmReviewRet())) {
memoryDO.setLlmReviewRet(memory.getLlmReviewRet().toString().trim());
}
return memoryDO;
}
private ChatMemory getMemory(ChatMemoryDO memoryDO) {
ChatMemory memory = new ChatMemory();
BeanUtils.copyProperties(memoryDO, memory);
memory.setStatus(MemoryStatus.valueOf(memoryDO.getStatus().trim()));
if (Objects.nonNull(memoryDO.getHumanReviewRet())) {
memory.setHumanReviewRet(MemoryReviewResult.valueOf(memoryDO.getHumanReviewRet().trim()));
}
if (Objects.nonNull(memoryDO.getLlmReviewRet())) {
memory.setLlmReviewRet(MemoryReviewResult.valueOf(memoryDO.getLlmReviewRet().trim()));
}
return memory;
}
}