(improvement)(chat) Opt update of ChatMemory (#1286)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-06-29 17:39:01 +08:00
committed by GitHub
parent 82e2a7e1a3
commit a3f17b3b68
12 changed files with 88 additions and 54 deletions

View File

@@ -15,6 +15,8 @@ import java.util.List;
@NoArgsConstructor @NoArgsConstructor
public class ChatMemoryFilter { public class ChatMemoryFilter {
private Integer agentId;
private String question; private String question;
private List<String> questions; private List<String> questions;

View File

@@ -0,0 +1,27 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import javax.validation.constraints.NotNull;
@Data
public class ChatMemoryUpdateReq extends RecordInfo {
@NotNull(message = "id不可为空")
private Long id;
private String dbSchema;
private String s2sql;
private MemoryStatus status;
private MemoryReviewResult humanReviewRet;
private String humanReviewCmt;
}

View File

@@ -6,10 +6,8 @@ import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString; import lombok.ToString;
import java.util.Date; import java.util.Date;
@@ -17,8 +15,6 @@ import java.util.Date;
@Data @Data
@Builder @Builder
@ToString @ToString
@AllArgsConstructor
@NoArgsConstructor
@TableName("s2_chat_memory") @TableName("s2_chat_memory")
public class ChatMemoryDO { public class ChatMemoryDO {
@TableId(type = IdType.AUTO) @TableId(type = IdType.AUTO)

View File

@@ -50,7 +50,7 @@ public class ChatPlugin extends RecordInfo {
return Lists.newArrayList(); return Lists.newArrayList();
} }
public boolean isContainsAllModel() { public boolean isContainsAllDataSet() {
return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L); return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L);
} }

View File

@@ -193,26 +193,26 @@ public class PluginManager {
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) { public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo(); SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, chatParseContext); Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) { if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
return Pair.of(false, Sets.newHashSet()); return Pair.of(false, Sets.newHashSet());
} }
List<ParamOption> paramOptions = getSemanticOption(plugin); List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) { if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, pluginMatchedModel); return Pair.of(true, pluginMatchedDataSet);
} }
Set<Long> matchedModel = Sets.newHashSet(); Set<Long> matchedDataSet = Sets.newHashSet();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream() Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
.collect(Collectors.groupingBy(ParamOption::getModelId)); .collect(Collectors.groupingBy(ParamOption::getDataSetId));
for (Long modelId : paramOptionMap.keySet()) { for (Long dataSetId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId); List<ParamOption> params = paramOptionMap.get(dataSetId);
if (CollectionUtils.isEmpty(params)) { if (CollectionUtils.isEmpty(params)) {
matchedModel.add(modelId); matchedDataSet.add(dataSetId);
continue; continue;
} }
boolean matched = true; boolean matched = true;
for (ParamOption paramOption : params) { for (ParamOption paramOption : params) {
Set<Long> elementIdSet = getSchemaElementMatch(modelId, schemaMapInfo); Set<Long> elementIdSet = getSchemaElementMatch(dataSetId, schemaMapInfo);
if (CollectionUtils.isEmpty(elementIdSet)) { if (CollectionUtils.isEmpty(elementIdSet)) {
matched = false; matched = false;
break; break;
@@ -223,13 +223,13 @@ public class PluginManager {
} }
} }
if (matched) { if (matched) {
matchedModel.add(modelId); matchedDataSet.add(dataSetId);
} }
} }
if (CollectionUtils.isEmpty(matchedModel)) { if (CollectionUtils.isEmpty(matchedDataSet)) {
return Pair.of(false, Sets.newHashSet()); return Pair.of(false, Sets.newHashSet());
} }
return Pair.of(true, matchedModel); return Pair.of(true, matchedDataSet);
} }
private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) { private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
@@ -251,7 +251,7 @@ public class PluginManager {
return null; return null;
} }
List<ParamOption> paramOptions = webBase.getParamOptions(); List<ParamOption> paramOptions = webBase.getParamOptions();
if (org.springframework.util.CollectionUtils.isEmpty(paramOptions)) { if (CollectionUtils.isEmpty(paramOptions)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
return paramOptions.stream() return paramOptions.stream()
@@ -259,19 +259,19 @@ public class PluginManager {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private static Set<Long> getPluginMatchedModel(ChatPlugin plugin, ChatParseContext chatParseContext) { private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos(); Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
if (plugin.isContainsAllModel()) { if (plugin.isContainsAllDataSet()) {
return Sets.newHashSet(plugin.getDefaultMode()); return Sets.newHashSet(plugin.getDefaultMode());
} }
List<Long> modelIds = plugin.getDataSetList(); List<Long> dataSetList = plugin.getDataSetList();
Set<Long> pluginMatchedModel = Sets.newHashSet(); Set<Long> pluginMatchedDataSet = Sets.newHashSet();
for (Long modelId : modelIds) { for (Long dataSetId : dataSetList) {
if (matchedDataSets.contains(modelId)) { if (matchedDataSets.contains(dataSetId)) {
pluginMatchedModel.add(modelId); pluginMatchedDataSet.add(dataSetId);
} }
} }
return pluginMatchedModel; return pluginMatchedDataSet;
} }
} }

View File

@@ -15,7 +15,7 @@ public class ParamOption {
private String keyAlias; private String keyAlias;
private Long modelId; private Long dataSetId;
private Long elementId; private Long elementId;

View File

@@ -77,8 +77,8 @@ public abstract class PluginSemanticQuery {
List<ParamOption> paramOptions = Lists.newArrayList(); List<ParamOption> paramOptions = Lists.newArrayList();
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) { if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
for (ParamOption paramOption : webPage.getParamOptions()) { for (ParamOption paramOption : webPage.getParamOptions()) {
if (paramOption.getModelId() != null if (paramOption.getDataSetId() != null
&& !parseInfo.getDataSetId().equals(paramOption.getModelId())) { && !parseInfo.getDataSetId().equals(paramOption.getDataSetId())) {
continue; continue;
} }
paramOptions.add(paramOption); paramOptions.add(paramOption);

View File

@@ -47,7 +47,7 @@ public abstract class PluginRecognizer {
PluginRecallResult pluginRecallResult) { PluginRecallResult pluginRecallResult) {
ChatPlugin plugin = pluginRecallResult.getPlugin(); ChatPlugin plugin = pluginRecallResult.getPlugin();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds(); Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
if (plugin.isContainsAllModel()) { if (plugin.isContainsAllDataSet()) {
dataSetIds = Sets.newHashSet(-1L); dataSetIds = Sets.newHashSet(-1L);
} }
for (Long dataSetId : dataSetIds) { for (Long dataSetId : dataSetIds) {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.rest;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -14,7 +15,6 @@ import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.util.Date;
@RestController @RestController
@RequestMapping({"/api/chat/memory"}) @RequestMapping({"/api/chat/memory"})
@@ -23,27 +23,13 @@ public class MemoryController {
@Autowired @Autowired
private MemoryService memoryService; private MemoryService memoryService;
@PostMapping("/createMemory")
public Boolean createMemory(@RequestBody ChatMemoryDO memory,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memory.setCreatedBy(user.getName());
memory.setUpdatedBy(user.getName());
memory.setCreatedAt(new Date());
memory.setUpdatedAt(new Date());
memoryService.createMemory(memory);
return true;
}
@PostMapping("/updateMemory") @PostMapping("/updateMemory")
public void updateMemory(ChatMemoryDO memory, public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
memory.setUpdatedBy(user.getName()); memoryService.updateMemory(chatMemoryUpdateReq, user);
memory.setUpdatedAt(new Date()); return true;
memoryService.updateMemory(memory);
} }
@RequestMapping("/pageMemories") @RequestMapping("/pageMemories")

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.server.service; package com.tencent.supersonic.chat.server.service;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
@@ -10,6 +12,8 @@ import java.util.List;
public interface MemoryService { public interface MemoryService {
void createMemory(ChatMemoryDO memory); void createMemory(ChatMemoryDO memory);
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
void updateMemory(ChatMemoryDO memory); void updateMemory(ChatMemoryDO memory);
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq); PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);

View File

@@ -96,7 +96,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
if (memoriesExisted.contains(example)) { if (memoriesExisted.contains(example)) {
continue; continue;
} }
chatService.parseAndExecute(-1, agent.getId(), example); try {
chatService.parseAndExecute(-1, agent.getId(), example);
} catch (Exception e) {
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
}
} }
} }

View File

@@ -3,14 +3,17 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -33,12 +36,21 @@ public class MemoryServiceImpl implements MemoryService {
} }
@Override @Override
public void updateMemory(ChatMemoryDO memory) { public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
if (MemoryStatus.ENABLED.equals(memory.getStatus())) { chatMemoryUpdateReq.updatedBy(user.getName());
enableMemory(memory); ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
} else if (MemoryStatus.DISABLED.equals(memory.getStatus())) { boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
disableMemory(memory); BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
enableMemory(chatMemoryDO);
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
disableMemory(chatMemoryDO);
} }
updateMemory(chatMemoryDO);
}
@Override
public void updateMemory(ChatMemoryDO memory) {
chatMemoryRepository.updateMemory(memory); chatMemoryRepository.updateMemory(memory);
} }
@@ -52,6 +64,9 @@ public class MemoryServiceImpl implements MemoryService {
@Override @Override
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) { public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>(); QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
if (chatMemoryFilter.getAgentId() != null) {
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId());
}
if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) { if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) {
queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion()); queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion());
} }