(improvement)(Chat) Put agent examples into ChatMemory automatically (#1249)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-06-27 17:12:35 +08:00
committed by GitHub
parent bbd61ac937
commit 89c63cd44d
20 changed files with 292 additions and 36 deletions

View File

@@ -0,0 +1,9 @@
package com.tencent.supersonic.chat.api.pojo.enums;
public enum MemoryReviewResult {
POSITIVE,
NEGATIVE
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.api.pojo.enums;
public enum MemoryStatus {
PENDING,
ENABLED,
DISABLED;
}

View File

@@ -0,0 +1,28 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatMemoryFilter {
private String question;
private List<String> questions;
private MemoryStatus status;
private MemoryReviewResult llmReviewRet;
private MemoryReviewResult humanReviewRet;
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.PageBaseReq;
import lombok.Data;
@Data
public class PageMemoryReq extends PageBaseReq {
private ChatMemoryFilter chatMemoryFilter = new ChatMemoryFilter();
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -32,7 +33,7 @@ public class SqlExecutor implements ChatExecutor {
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemoryDO.builder()
.agentId(chatExecuteContext.getAgentId())
.status(ChatMemoryDO.Status.PENDING)
.status(MemoryStatus.PENDING)
.question(chatExecuteContext.getQueryText())
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))

View File

@@ -0,0 +1,49 @@
package com.tencent.supersonic.chat.server.memory;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
public class AgentExample2MemoryTransformer {
@Autowired
private ChatService chatService;
@Autowired
private MemoryService memoryService;
@Async
public void transform(Agent agent) {
if (!agent.containsLLMParserTool() || agent.getLlmConfig() == null) {
return;
}
List<String> examples = agent.getExamples();
ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().questions(examples).build();
List<String> memoriesExisted = memoryService.getMemories(chatMemoryFilter)
.stream().map(ChatMemoryDO::getQuestion).collect(Collectors.toList());
for (String example : examples) {
if (memoriesExisted.contains(example)) {
continue;
}
ChatParseReq chatParseReq = new ChatParseReq();
chatParseReq.setAgentId(agent.getId());
chatParseReq.setQueryText(example);
chatParseReq.setUser(User.getFakeUser());
chatParseReq.setChatId(-1);
chatService.parseAndExecute(chatParseReq);
}
}
}

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.server.memory;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.util.S2ChatModelProvider;
@@ -48,8 +47,7 @@ public class MemoryReviewTask {
@Scheduled(fixedDelay = 60 * 1000)
public void review() {
memoryService.getMemories().stream()
.filter(c -> c.getStatus() == ChatMemoryDO.Status.PENDING)
memoryService.getMemoriesForLlmReview().stream()
.forEach(m -> {
Agent chatAgent = agentService.getAgent(m.getAgentId());
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql());
@@ -62,7 +60,7 @@ public class MemoryReviewTask {
Matcher matcher = OUTPUT_PATTERN.matcher(response);
if (matcher.find()) {
m.setLlmReviewRet(ReviewResult.valueOf(matcher.group(1)));
m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1)));
m.setLlmReviewCmt(matcher.group(2));
memoryService.updateMemory(m);
}

View File

@@ -4,6 +4,8 @@ import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
@@ -31,16 +33,16 @@ public class ChatMemoryDO {
private String s2sql;
@TableField("status")
private Status status;
private MemoryStatus status;
@TableField("llm_review")
private ReviewResult llmReviewRet;
private MemoryReviewResult llmReviewRet;
@TableField("llm_comment")
private String llmReviewCmt;
@TableField("human_review")
private ReviewResult humanReviewRet;
private MemoryReviewResult humanReviewRet;
@TableField("human_comment")
private String humanReviewCmt;
@@ -57,14 +59,4 @@ public class ChatMemoryDO {
@TableField("updated_at")
private Date updatedAt;
public enum ReviewResult {
POSITIVE,
NEGATIVE
}
public enum Status {
PENDING,
ENABLED,
DISABLED;
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import java.util.List;
@@ -11,5 +12,5 @@ public interface ChatMemoryRepository {
ChatMemoryDO getMemory(Long id);
List<ChatMemoryDO> getMemories();
List<ChatMemoryDO> getMemories(QueryWrapper<ChatMemoryDO> queryWrapper);
}

View File

@@ -35,8 +35,8 @@ public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
}
@Override
public List<ChatMemoryDO> getMemories() {
return chatMemoryMapper.selectList(new QueryWrapper<>());
public List<ChatMemoryDO> getMemories(QueryWrapper<ChatMemoryDO> queryWrapper) {
return chatMemoryMapper.selectList(queryWrapper);
}
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.chat.server.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.service.MemoryService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Date;
@RestController
@RequestMapping({"/api/chat/memory"})
public class MemoryController {
@Autowired
private MemoryService memoryService;
@PostMapping("/createMemory")
public Boolean createMemory(@RequestBody ChatMemoryDO memory,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memory.setCreatedBy(user.getName());
memory.setUpdatedBy(user.getName());
memory.setCreatedAt(new Date());
memory.setUpdatedAt(new Date());
memoryService.createMemory(memory);
return true;
}
@PostMapping("/updateMemory")
public void updateMemory(ChatMemoryDO memory,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memory.setUpdatedBy(user.getName());
memory.setUpdatedAt(new Date());
memoryService.updateMemory(memory);
}
@RequestMapping("/pageMemories")
public PageInfo<ChatMemoryDO> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
return memoryService.pageMemories(pageMemoryReq);
}
}

View File

@@ -20,6 +20,8 @@ public interface ChatService {
QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception;
QueryResult parseAndExecute(ChatParseReq chatParseReq);
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
SemanticParseInfo queryContext(Integer chatId);

View File

@@ -1,5 +1,8 @@
package com.tencent.supersonic.chat.server.service;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import java.util.List;
@@ -9,5 +12,9 @@ public interface MemoryService {
void updateMemory(ChatMemoryDO memory);
List<ChatMemoryDO> getMemories();
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
List<ChatMemoryDO> getMemoriesForLlmReview();
}

View File

@@ -4,13 +4,15 @@ import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.memory.AgentExample2MemoryTransformer;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.config.LLMConfig;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.stream.Collectors;
@@ -19,6 +21,9 @@ import java.util.stream.Collectors;
public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
implements AgentService {
@Autowired
private AgentExample2MemoryTransformer agentExample2MemoryTransformer;
@Override
public List<Agent> getAgents() {
return getAgentDOList().stream()
@@ -30,6 +35,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
agent.createdBy(user.getName());
AgentDO agentDO = convert(agent);
save(agentDO);
agentExample2MemoryTransformer.transform(agent);
return agentDO.getId();
}
@@ -37,6 +43,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
public void updateAgent(Agent agent, User user) {
agent.updatedBy(user.getName());
updateById(convert(agent));
agentExample2MemoryTransformer.transform(agent);
}
@Override

View File

@@ -95,6 +95,19 @@ public class ChatServiceImpl implements ChatService {
return queryResult;
}
@Override
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
ParseResp parseResp = performParsing(chatParseReq);
ChatExecuteReq chatExecuteReq = new ChatExecuteReq();
chatExecuteReq.setQueryId(parseResp.getQueryId());
chatExecuteReq.setChatId(chatParseReq.getChatId());
chatExecuteReq.setUser(chatParseReq.getUser());
chatExecuteReq.setAgentId(chatParseReq.getAgentId());
chatExecuteReq.setQueryText(chatParseReq.getQueryText());
chatExecuteReq.setParseId(parseResp.getSelectedParses().get(0).getId());
return performExecution(chatExecuteReq);
}
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
ChatParseContext chatParseContext = new ChatParseContext();
BeanMapper.mapper(chatParseReq, chatParseContext);

View File

@@ -1,14 +1,20 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.Status;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.List;
@@ -23,24 +29,53 @@ public class MemoryServiceImpl implements MemoryService {
@Override
public void createMemory(ChatMemoryDO memory) {
if (ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
enableMemory(memory);
}
chatMemoryRepository.createMemory(memory);
}
@Override
public void updateMemory(ChatMemoryDO memory) {
if (!ChatMemoryDO.Status.ENABLED.equals(memory.getStatus())
&& ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
if (MemoryStatus.ENABLED.equals(memory.getStatus())) {
enableMemory(memory);
} else if (MemoryStatus.DISABLED.equals(memory.getStatus())) {
disableMemory(memory);
}
chatMemoryRepository.updateMemory(memory);
}
@Override
public List<ChatMemoryDO> getMemories() {
return chatMemoryRepository.getMemories();
public PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq) {
return PageHelper.startPage(pageMemoryReq.getCurrent(),
pageMemoryReq.getPageSize())
.doSelectPageInfo(() -> getMemories(pageMemoryReq.getChatMemoryFilter()));
}
@Override
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) {
queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion());
}
if (!CollectionUtils.isEmpty(chatMemoryFilter.getQuestions())) {
queryWrapper.lambda().in(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestions());
}
if (chatMemoryFilter.getStatus() != null) {
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, chatMemoryFilter.getStatus());
}
if (chatMemoryFilter.getHumanReviewRet() != null) {
queryWrapper.lambda().eq(ChatMemoryDO::getHumanReviewRet, chatMemoryFilter.getHumanReviewRet());
}
if (chatMemoryFilter.getLlmReviewRet() != null) {
queryWrapper.lambda().eq(ChatMemoryDO::getLlmReviewRet, chatMemoryFilter.getLlmReviewRet());
}
return chatMemoryRepository.getMemories(queryWrapper);
}
@Override
public List<ChatMemoryDO> getMemoriesForLlmReview() {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
.isNull(ChatMemoryDO::getLlmReviewRet);
return chatMemoryRepository.getMemories(queryWrapper);
}
private void enableMemory(ChatMemoryDO memory) {
@@ -50,6 +85,15 @@ public class MemoryServiceImpl implements MemoryService {
.dbSchema(memory.getDbSchema())
.sql(memory.getS2sql())
.build());
memory.setStatus(Status.ENABLED);
}
private void disableMemory(ChatMemoryDO memory) {
exemplarService.removeExemplar(memory.getAgentId().toString(),
SqlExemplar.builder()
.question(memory.getQuestion())
.dbSchema(memory.getDbSchema())
.sql(memory.getS2sql())
.build());
}
}

View File

@@ -7,6 +7,8 @@ import java.util.List;
public interface ExemplarService {
void storeExemplar(String collection, SqlExemplar exemplar);
void removeExemplar(String collection, SqlExemplar exemplar);
List<SqlExemplar> recallExemplars(String collection, String query, int num);
List<SqlExemplar> recallExemplars(String query, int num);

View File

@@ -49,6 +49,14 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
embeddingService.addQuery(collection, Lists.newArrayList(segment));
}
public void removeExemplar(String collection, SqlExemplar exemplar) {
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
String.class, Object.class));
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
embeddingService.deleteQuery(collection, Lists.newArrayList(segment));
}
public List<SqlExemplar> recallExemplars(String query, int num) {
String collection = embeddingConfig.getText2sqlCollectionName();
return recallExemplars(collection, query, num);

View File

@@ -329,3 +329,22 @@ alter table s2_agent add column `visual_config` varchar(2000) COLLATE utf8_unic
alter table s2_term add column `related_metrics` varchar(1000) DEFAULT NULL COMMENT '术语关联的指标';
alter table s2_term add column `related_dimensions` varchar(1000) DEFAULT NULL COMMENT '术语关联的维度';
--20240627
CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
`id` INT NOT NULL AUTO_INCREMENT,
`question` varchar(655) ,
`agent_id` INT ,
`db_schema` TEXT ,
`s2_sql` TEXT ,
`status` char(10) ,
`llm_review` char(10) ,
`llm_comment` TEXT,
`human_review` char(10) ,
`human_comment` TEXT ,
`created_at` datetime NOT NULL ,
`updated_at` datetime NOT NULL ,
`created_by` varchar(100) NOT NULL ,
`updated_by` varchar(100) NOT NULL ,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;

View File

@@ -147,8 +147,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
`llm_comment` TEXT,
`human_review` char(10) ,
`human_comment` TEXT ,
`created_at` TIMESTAMP NOT NULL ,
`updated_at` TIMESTAMP NOT NULL ,
`created_at` datetime NOT NULL ,
`updated_at` datetime NOT NULL ,
`created_by` varchar(100) NOT NULL ,
`updated_by` varchar(100) NOT NULL ,
PRIMARY KEY (`id`)