mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement][chat]Support reviewing query memory based on direct user feedback.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -20,3 +20,4 @@ chm_db/
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
/dict
|
/dict
|
||||||
assembly/build/*-SNAPSHOT
|
assembly/build/*-SNAPSHOT
|
||||||
|
**/node_modules/
|
||||||
@@ -17,6 +17,8 @@ public class ChatMemoryFilter {
|
|||||||
|
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
private Long queryId;
|
||||||
|
|
||||||
private String question;
|
private String question;
|
||||||
|
|
||||||
private List<String> questions;
|
private List<String> questions;
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor {
|
|||||||
Text2SQLExemplar.class);
|
Text2SQLExemplar.class);
|
||||||
|
|
||||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||||
memoryService.createMemory(ChatMemory.builder()
|
memoryService.createMemory(ChatMemory.builder().queryId(queryResult.getQueryId())
|
||||||
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
|
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
|
||||||
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
|
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
|
||||||
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
|
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ public class ChatMemoryDO {
|
|||||||
@TableField("agent_id")
|
@TableField("agent_id")
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
@TableField("query_id")
|
||||||
|
private Long queryId;
|
||||||
|
|
||||||
@TableField("question")
|
@TableField("question")
|
||||||
private String question;
|
private String question;
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,7 @@ package com.tencent.supersonic.chat.server.pojo;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.*;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.ToString;
|
|
||||||
|
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
|
||||||
@@ -20,6 +16,8 @@ public class ChatMemory {
|
|||||||
|
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
private Long queryId;
|
||||||
|
|
||||||
private String question;
|
private String question;
|
||||||
|
|
||||||
private String sideInfo;
|
private String sideInfo;
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ public class ChatController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/updateQAFeedback")
|
@PostMapping("/updateQAFeedback")
|
||||||
public Boolean updateQAFeedback(@RequestParam(value = "id") Integer id,
|
public Boolean updateQAFeedback(@RequestParam(value = "id") Long id,
|
||||||
@RequestParam(value = "score") Integer score,
|
@RequestParam(value = "score") Integer score,
|
||||||
@RequestParam(value = "feedback", required = false) String feedback) {
|
@RequestParam(value = "feedback", required = false) String feedback) {
|
||||||
return chatService.updateFeedback(id, score, feedback);
|
return chatService.updateFeedback(id, score, feedback);
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -24,7 +23,7 @@ public interface ChatManageService {
|
|||||||
|
|
||||||
boolean updateChatName(Long chatId, String chatName, String userName);
|
boolean updateChatName(Long chatId, String chatName, String userName);
|
||||||
|
|
||||||
boolean updateFeedback(Integer id, Integer score, String feedback);
|
boolean updateFeedback(Long id, Integer score, String feedback);
|
||||||
|
|
||||||
boolean updateChatIsTop(Long chatId, int isTop);
|
boolean updateChatIsTop(Long chatId, int isTop);
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.server.service.impl;
|
|||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.github.pagehelper.PageInfo;
|
import com.github.pagehelper.PageInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
import com.tencent.supersonic.chat.api.pojo.request.*;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
@@ -15,11 +15,12 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -38,6 +39,8 @@ public class ChatManageServiceImpl implements ChatManageService {
|
|||||||
private ChatRepository chatRepository;
|
private ChatRepository chatRepository;
|
||||||
@Autowired
|
@Autowired
|
||||||
private ChatQueryRepository chatQueryRepository;
|
private ChatQueryRepository chatQueryRepository;
|
||||||
|
@Autowired
|
||||||
|
private MemoryService memoryService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Long addChat(User user, String chatName, Integer agentId) {
|
public Long addChat(User user, String chatName, Integer agentId) {
|
||||||
@@ -64,11 +67,28 @@ public class ChatManageServiceImpl implements ChatManageService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean updateFeedback(Integer id, Integer score, String feedback) {
|
public boolean updateFeedback(Long id, Integer score, String feedback) {
|
||||||
QueryDO intelligentQueryDO = new QueryDO();
|
QueryDO intelligentQueryDO = new QueryDO();
|
||||||
intelligentQueryDO.setId(id);
|
intelligentQueryDO.setId(id);
|
||||||
|
intelligentQueryDO.setQuestionId(id);
|
||||||
intelligentQueryDO.setScore(score);
|
intelligentQueryDO.setScore(score);
|
||||||
intelligentQueryDO.setFeedback(feedback);
|
intelligentQueryDO.setFeedback(feedback);
|
||||||
|
|
||||||
|
// enable or disable memory based on user feedback
|
||||||
|
if (score >= 5 || score <= 1) {
|
||||||
|
ChatMemoryFilter memoryFilter = ChatMemoryFilter.builder().queryId(id).build();
|
||||||
|
List<ChatMemory> memories = memoryService.getMemories(memoryFilter);
|
||||||
|
memories.forEach(m -> {
|
||||||
|
MemoryStatus status = score >= 5 ? MemoryStatus.ENABLED : MemoryStatus.DISABLED;
|
||||||
|
MemoryReviewResult reviewResult =
|
||||||
|
score >= 5 ? MemoryReviewResult.POSITIVE : MemoryReviewResult.NEGATIVE;
|
||||||
|
ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder().id(m.getId())
|
||||||
|
.status(status).humanReviewRet(reviewResult)
|
||||||
|
.humanReviewCmt("Reviewed as per user feedback").build();
|
||||||
|
memoryService.updateMemory(memoryUpdateReq, User.getDefaultUser());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return chatRepository.updateFeedback(intelligentQueryDO);
|
return chatRepository.updateFeedback(intelligentQueryDO);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
|
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void enableMemory(ChatMemoryDO memory) {
|
public void enableMemory(ChatMemoryDO memory) {
|
||||||
memory.setStatus(MemoryStatus.ENABLED.toString());
|
memory.setStatus(MemoryStatus.ENABLED.toString());
|
||||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
Text2SQLExemplar.builder().question(memory.getQuestion())
|
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||||
@@ -148,7 +148,7 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
.sql(memory.getS2sql()).build());
|
.sql(memory.getS2sql()).build());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void disableMemory(ChatMemoryDO memory) {
|
public void disableMemory(ChatMemoryDO memory) {
|
||||||
memory.setStatus(MemoryStatus.DISABLED.toString());
|
memory.setStatus(MemoryStatus.DISABLED.toString());
|
||||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
Text2SQLExemplar.builder().question(memory.getQuestion())
|
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
|||||||
return relateDimension;
|
return relateDimension;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void updateQueryScore(Integer queryId) {
|
protected void updateQueryScore(Long queryId) {
|
||||||
chatManageService.updateFeedback(queryId, 5, "");
|
chatManageService.updateFeedback(queryId, 5, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -115,8 +115,8 @@ public class S2VisitsDemo extends S2BaseDemo {
|
|||||||
// create agent
|
// create agent
|
||||||
Integer agentId = addAgent(s2DataSet.getId());
|
Integer agentId = addAgent(s2DataSet.getId());
|
||||||
addSampleChats(agentId);
|
addSampleChats(agentId);
|
||||||
updateQueryScore(1);
|
updateQueryScore(1L);
|
||||||
updateQueryScore(4);
|
updateQueryScore(4L);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("Failed to add S2Visits demo data", e);
|
log.error("Failed to add S2Visits demo data", e);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -402,3 +402,6 @@ alter table s2_agent add column `viewer` varchar(1000) COLLATE utf8_unicode_ci D
|
|||||||
--20241201
|
--20241201
|
||||||
ALTER TABLE s2_query_stat_info RENAME COLUMN `user` TO `query_user`;
|
ALTER TABLE s2_query_stat_info RENAME COLUMN `user` TO `query_user`;
|
||||||
ALTER TABLE s2_chat_context RENAME COLUMN `user` TO `query_user`;
|
ALTER TABLE s2_chat_context RENAME COLUMN `user` TO `query_user`;
|
||||||
|
|
||||||
|
--20241226
|
||||||
|
alter table s2_chat_memory add column `query_id` BIGINT DEFAULT NULL;
|
||||||
@@ -86,6 +86,7 @@ COMMENT ON TABLE s2_chat_config IS 'chat config information table ';
|
|||||||
CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
||||||
`id` INT NOT NULL AUTO_INCREMENT,
|
`id` INT NOT NULL AUTO_INCREMENT,
|
||||||
`question` varchar(655) ,
|
`question` varchar(655) ,
|
||||||
|
`query_id` BIGINT ,
|
||||||
`agent_id` INT ,
|
`agent_id` INT ,
|
||||||
`db_schema` TEXT ,
|
`db_schema` TEXT ,
|
||||||
`s2_sql` TEXT ,
|
`s2_sql` TEXT ,
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
|||||||
`id` INT NOT NULL AUTO_INCREMENT,
|
`id` INT NOT NULL AUTO_INCREMENT,
|
||||||
`question` varchar(655) COMMENT '用户问题' ,
|
`question` varchar(655) COMMENT '用户问题' ,
|
||||||
`side_info` TEXT COMMENT '辅助信息' ,
|
`side_info` TEXT COMMENT '辅助信息' ,
|
||||||
|
`query_id` BIGINT COMMENT '问答ID' ,
|
||||||
`agent_id` INT COMMENT '助理ID' ,
|
`agent_id` INT COMMENT '助理ID' ,
|
||||||
`db_schema` TEXT COMMENT 'Schema映射' ,
|
`db_schema` TEXT COMMENT 'Schema映射' ,
|
||||||
`s2_sql` TEXT COMMENT '大模型解析SQL' ,
|
`s2_sql` TEXT COMMENT '大模型解析SQL' ,
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ CREATE TABLE IF NOT EXISTS s2_chat_memory (
|
|||||||
id SERIAL PRIMARY KEY,
|
id SERIAL PRIMARY KEY,
|
||||||
question varchar(655),
|
question varchar(655),
|
||||||
side_info TEXT,
|
side_info TEXT,
|
||||||
|
query_id bigint,
|
||||||
agent_id INTEGER,
|
agent_id INTEGER,
|
||||||
db_schema TEXT,
|
db_schema TEXT,
|
||||||
s2_sql TEXT,
|
s2_sql TEXT,
|
||||||
|
|||||||
Reference in New Issue
Block a user