[improvement][chat]Support reviewing query memory based on direct user feedback.

This commit is contained in:
jerryjzhang
2024-12-26 09:47:13 +08:00
parent 68963b9ec9
commit d04a086c88
15 changed files with 49 additions and 20 deletions

3
.gitignore vendored
View File

@@ -19,4 +19,5 @@ assembly/runtime/*
chm_db/ chm_db/
__pycache__/ __pycache__/
/dict /dict
assembly/build/*-SNAPSHOT assembly/build/*-SNAPSHOT
**/node_modules/

View File

@@ -17,6 +17,8 @@ public class ChatMemoryFilter {
private Integer agentId; private Integer agentId;
private Long queryId;
private String question; private String question;
private List<String> questions; private List<String> questions;

View File

@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor {
Text2SQLExemplar.class); Text2SQLExemplar.class);
MemoryService memoryService = ContextUtils.getBean(MemoryService.class); MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemory.builder() memoryService.createMemory(ChatMemory.builder().queryId(queryResult.getQueryId())
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING) .agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo()) .question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql()) .dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())

View File

@@ -23,6 +23,9 @@ public class ChatMemoryDO {
@TableField("agent_id") @TableField("agent_id")
private Integer agentId; private Integer agentId;
@TableField("query_id")
private Long queryId;
@TableField("question") @TableField("question")
private String question; private String question;

View File

@@ -2,11 +2,7 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor; import lombok.*;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.util.Date; import java.util.Date;
@@ -20,6 +16,8 @@ public class ChatMemory {
private Integer agentId; private Integer agentId;
private Long queryId;
private String question; private String question;
private String sideInfo; private String sideInfo;

View File

@@ -53,7 +53,7 @@ public class ChatController {
} }
@PostMapping("/updateQAFeedback") @PostMapping("/updateQAFeedback")
public Boolean updateQAFeedback(@RequestParam(value = "id") Integer id, public Boolean updateQAFeedback(@RequestParam(value = "id") Long id,
@RequestParam(value = "score") Integer score, @RequestParam(value = "score") Integer score,
@RequestParam(value = "feedback", required = false) String feedback) { @RequestParam(value = "feedback", required = false) String feedback) {
return chatService.updateFeedback(id, score, feedback); return chatService.updateFeedback(id, score, feedback);

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import java.util.List; import java.util.List;
@@ -24,7 +23,7 @@ public interface ChatManageService {
boolean updateChatName(Long chatId, String chatName, String userName); boolean updateChatName(Long chatId, String chatName, String userName);
boolean updateFeedback(Integer id, Integer score, String feedback); boolean updateFeedback(Long id, Integer score, String feedback);
boolean updateChatIsTop(Long chatId, int isTop); boolean updateChatIsTop(Long chatId, int isTop);

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.*;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp; import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
@@ -15,11 +15,12 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -38,6 +39,8 @@ public class ChatManageServiceImpl implements ChatManageService {
private ChatRepository chatRepository; private ChatRepository chatRepository;
@Autowired @Autowired
private ChatQueryRepository chatQueryRepository; private ChatQueryRepository chatQueryRepository;
@Autowired
private MemoryService memoryService;
@Override @Override
public Long addChat(User user, String chatName, Integer agentId) { public Long addChat(User user, String chatName, Integer agentId) {
@@ -64,11 +67,28 @@ public class ChatManageServiceImpl implements ChatManageService {
} }
@Override @Override
public boolean updateFeedback(Integer id, Integer score, String feedback) { public boolean updateFeedback(Long id, Integer score, String feedback) {
QueryDO intelligentQueryDO = new QueryDO(); QueryDO intelligentQueryDO = new QueryDO();
intelligentQueryDO.setId(id); intelligentQueryDO.setId(id);
intelligentQueryDO.setQuestionId(id);
intelligentQueryDO.setScore(score); intelligentQueryDO.setScore(score);
intelligentQueryDO.setFeedback(feedback); intelligentQueryDO.setFeedback(feedback);
// enable or disable memory based on user feedback
if (score >= 5 || score <= 1) {
ChatMemoryFilter memoryFilter = ChatMemoryFilter.builder().queryId(id).build();
List<ChatMemory> memories = memoryService.getMemories(memoryFilter);
memories.forEach(m -> {
MemoryStatus status = score >= 5 ? MemoryStatus.ENABLED : MemoryStatus.DISABLED;
MemoryReviewResult reviewResult =
score >= 5 ? MemoryReviewResult.POSITIVE : MemoryReviewResult.NEGATIVE;
ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder().id(m.getId())
.status(status).humanReviewRet(reviewResult)
.humanReviewCmt("Reviewed as per user feedback").build();
memoryService.updateMemory(memoryUpdateReq, User.getDefaultUser());
});
}
return chatRepository.updateFeedback(intelligentQueryDO); return chatRepository.updateFeedback(intelligentQueryDO);
} }

View File

@@ -140,7 +140,7 @@ public class MemoryServiceImpl implements MemoryService {
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList()); return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
} }
private void enableMemory(ChatMemoryDO memory) { public void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED.toString()); memory.setStatus(MemoryStatus.ENABLED.toString());
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion()) Text2SQLExemplar.builder().question(memory.getQuestion())
@@ -148,7 +148,7 @@ public class MemoryServiceImpl implements MemoryService {
.sql(memory.getS2sql()).build()); .sql(memory.getS2sql()).build());
} }
private void disableMemory(ChatMemoryDO memory) { public void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED.toString()); memory.setStatus(MemoryStatus.DISABLED.toString());
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion()) Text2SQLExemplar.builder().question(memory.getQuestion())

View File

@@ -190,7 +190,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
return relateDimension; return relateDimension;
} }
protected void updateQueryScore(Integer queryId) { protected void updateQueryScore(Long queryId) {
chatManageService.updateFeedback(queryId, 5, ""); chatManageService.updateFeedback(queryId, 5, "");
} }

View File

@@ -115,8 +115,8 @@ public class S2VisitsDemo extends S2BaseDemo {
// create agent // create agent
Integer agentId = addAgent(s2DataSet.getId()); Integer agentId = addAgent(s2DataSet.getId());
addSampleChats(agentId); addSampleChats(agentId);
updateQueryScore(1); updateQueryScore(1L);
updateQueryScore(4); updateQueryScore(4L);
} catch (Exception e) { } catch (Exception e) {
log.error("Failed to add S2Visits demo data", e); log.error("Failed to add S2Visits demo data", e);
} }

View File

@@ -402,3 +402,6 @@ alter table s2_agent add column `viewer` varchar(1000) COLLATE utf8_unicode_ci D
--20241201 --20241201
ALTER TABLE s2_query_stat_info RENAME COLUMN `user` TO `query_user`; ALTER TABLE s2_query_stat_info RENAME COLUMN `user` TO `query_user`;
ALTER TABLE s2_chat_context RENAME COLUMN `user` TO `query_user`; ALTER TABLE s2_chat_context RENAME COLUMN `user` TO `query_user`;
--20241226
alter table s2_chat_memory add column `query_id` BIGINT DEFAULT NULL;

View File

@@ -86,6 +86,7 @@ COMMENT ON TABLE s2_chat_config IS 'chat config information table ';
CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
`id` INT NOT NULL AUTO_INCREMENT, `id` INT NOT NULL AUTO_INCREMENT,
`question` varchar(655) , `question` varchar(655) ,
`query_id` BIGINT ,
`agent_id` INT , `agent_id` INT ,
`db_schema` TEXT , `db_schema` TEXT ,
`s2_sql` TEXT , `s2_sql` TEXT ,

View File

@@ -77,6 +77,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
`id` INT NOT NULL AUTO_INCREMENT, `id` INT NOT NULL AUTO_INCREMENT,
`question` varchar(655) COMMENT '用户问题' , `question` varchar(655) COMMENT '用户问题' ,
`side_info` TEXT COMMENT '辅助信息' , `side_info` TEXT COMMENT '辅助信息' ,
`query_id` BIGINT COMMENT '问答ID' ,
`agent_id` INT COMMENT '助理ID' , `agent_id` INT COMMENT '助理ID' ,
`db_schema` TEXT COMMENT 'Schema映射' , `db_schema` TEXT COMMENT 'Schema映射' ,
`s2_sql` TEXT COMMENT '大模型解析SQL' , `s2_sql` TEXT COMMENT '大模型解析SQL' ,

View File

@@ -70,6 +70,7 @@ CREATE TABLE IF NOT EXISTS s2_chat_memory (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
question varchar(655), question varchar(655),
side_info TEXT, side_info TEXT,
query_id bigint,
agent_id INTEGER, agent_id INTEGER,
db_schema TEXT, db_schema TEXT,
s2_sql TEXT, s2_sql TEXT,