(improvement)(chat) Opt update of ChatMemory (#1286)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-06-29 17:39:01 +08:00
committed by GitHub
parent 82e2a7e1a3
commit a3f17b3b68
12 changed files with 88 additions and 54 deletions

View File

@@ -15,6 +15,8 @@ import java.util.List;
@NoArgsConstructor
public class ChatMemoryFilter {
private Integer agentId;
private String question;
private List<String> questions;

View File

@@ -0,0 +1,27 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import javax.validation.constraints.NotNull;
@Data
public class ChatMemoryUpdateReq extends RecordInfo {
@NotNull(message = "id不可为空")
private Long id;
private String dbSchema;
private String s2sql;
private MemoryStatus status;
private MemoryReviewResult humanReviewRet;
private String humanReviewCmt;
}

View File

@@ -6,10 +6,8 @@ import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.util.Date;
@@ -17,8 +15,6 @@ import java.util.Date;
@Data
@Builder
@ToString
@AllArgsConstructor
@NoArgsConstructor
@TableName("s2_chat_memory")
public class ChatMemoryDO {
@TableId(type = IdType.AUTO)

View File

@@ -50,7 +50,7 @@ public class ChatPlugin extends RecordInfo {
return Lists.newArrayList();
}
public boolean isContainsAllModel() {
public boolean isContainsAllDataSet() {
return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L);
}

View File

@@ -193,26 +193,26 @@ public class PluginManager {
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, chatParseContext);
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
return Pair.of(false, Sets.newHashSet());
}
List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, pluginMatchedModel);
return Pair.of(true, pluginMatchedDataSet);
}
Set<Long> matchedModel = Sets.newHashSet();
Set<Long> matchedDataSet = Sets.newHashSet();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
.collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId);
.collect(Collectors.groupingBy(ParamOption::getDataSetId));
for (Long dataSetId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(dataSetId);
if (CollectionUtils.isEmpty(params)) {
matchedModel.add(modelId);
matchedDataSet.add(dataSetId);
continue;
}
boolean matched = true;
for (ParamOption paramOption : params) {
Set<Long> elementIdSet = getSchemaElementMatch(modelId, schemaMapInfo);
Set<Long> elementIdSet = getSchemaElementMatch(dataSetId, schemaMapInfo);
if (CollectionUtils.isEmpty(elementIdSet)) {
matched = false;
break;
@@ -223,13 +223,13 @@ public class PluginManager {
}
}
if (matched) {
matchedModel.add(modelId);
matchedDataSet.add(dataSetId);
}
}
if (CollectionUtils.isEmpty(matchedModel)) {
if (CollectionUtils.isEmpty(matchedDataSet)) {
return Pair.of(false, Sets.newHashSet());
}
return Pair.of(true, matchedModel);
return Pair.of(true, matchedDataSet);
}
private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
@@ -251,7 +251,7 @@ public class PluginManager {
return null;
}
List<ParamOption> paramOptions = webBase.getParamOptions();
if (org.springframework.util.CollectionUtils.isEmpty(paramOptions)) {
if (CollectionUtils.isEmpty(paramOptions)) {
return Lists.newArrayList();
}
return paramOptions.stream()
@@ -259,19 +259,19 @@ public class PluginManager {
.collect(Collectors.toList());
}
private static Set<Long> getPluginMatchedModel(ChatPlugin plugin, ChatParseContext chatParseContext) {
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
if (plugin.isContainsAllModel()) {
if (plugin.isContainsAllDataSet()) {
return Sets.newHashSet(plugin.getDefaultMode());
}
List<Long> modelIds = plugin.getDataSetList();
Set<Long> pluginMatchedModel = Sets.newHashSet();
for (Long modelId : modelIds) {
if (matchedDataSets.contains(modelId)) {
pluginMatchedModel.add(modelId);
List<Long> dataSetList = plugin.getDataSetList();
Set<Long> pluginMatchedDataSet = Sets.newHashSet();
for (Long dataSetId : dataSetList) {
if (matchedDataSets.contains(dataSetId)) {
pluginMatchedDataSet.add(dataSetId);
}
}
return pluginMatchedModel;
return pluginMatchedDataSet;
}
}

View File

@@ -15,7 +15,7 @@ public class ParamOption {
private String keyAlias;
private Long modelId;
private Long dataSetId;
private Long elementId;

View File

@@ -77,8 +77,8 @@ public abstract class PluginSemanticQuery {
List<ParamOption> paramOptions = Lists.newArrayList();
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
for (ParamOption paramOption : webPage.getParamOptions()) {
if (paramOption.getModelId() != null
&& !parseInfo.getDataSetId().equals(paramOption.getModelId())) {
if (paramOption.getDataSetId() != null
&& !parseInfo.getDataSetId().equals(paramOption.getDataSetId())) {
continue;
}
paramOptions.add(paramOption);

View File

@@ -47,7 +47,7 @@ public abstract class PluginRecognizer {
PluginRecallResult pluginRecallResult) {
ChatPlugin plugin = pluginRecallResult.getPlugin();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
if (plugin.isContainsAllModel()) {
if (plugin.isContainsAllDataSet()) {
dataSetIds = Sets.newHashSet(-1L);
}
for (Long dataSetId : dataSetIds) {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -14,7 +15,6 @@ import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Date;
@RestController
@RequestMapping({"/api/chat/memory"})
@@ -23,27 +23,13 @@ public class MemoryController {
@Autowired
private MemoryService memoryService;
@PostMapping("/createMemory")
public Boolean createMemory(@RequestBody ChatMemoryDO memory,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memory.setCreatedBy(user.getName());
memory.setUpdatedBy(user.getName());
memory.setCreatedAt(new Date());
memory.setUpdatedAt(new Date());
memoryService.createMemory(memory);
return true;
}
@PostMapping("/updateMemory")
public void updateMemory(ChatMemoryDO memory,
public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memory.setUpdatedBy(user.getName());
memory.setUpdatedAt(new Date());
memoryService.updateMemory(memory);
memoryService.updateMemory(chatMemoryUpdateReq, user);
return true;
}
@RequestMapping("/pageMemories")

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.server.service;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
@@ -10,6 +12,8 @@ import java.util.List;
public interface MemoryService {
void createMemory(ChatMemoryDO memory);
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
void updateMemory(ChatMemoryDO memory);
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);

View File

@@ -96,7 +96,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
if (memoriesExisted.contains(example)) {
continue;
}
chatService.parseAndExecute(-1, agent.getId(), example);
try {
chatService.parseAndExecute(-1, agent.getId(), example);
} catch (Exception e) {
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
}
}
}

View File

@@ -3,14 +3,17 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -33,12 +36,21 @@ public class MemoryServiceImpl implements MemoryService {
}
@Override
public void updateMemory(ChatMemoryDO memory) {
if (MemoryStatus.ENABLED.equals(memory.getStatus())) {
enableMemory(memory);
} else if (MemoryStatus.DISABLED.equals(memory.getStatus())) {
disableMemory(memory);
public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
chatMemoryUpdateReq.updatedBy(user.getName());
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
enableMemory(chatMemoryDO);
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
disableMemory(chatMemoryDO);
}
updateMemory(chatMemoryDO);
}
@Override
public void updateMemory(ChatMemoryDO memory) {
chatMemoryRepository.updateMemory(memory);
}
@@ -52,6 +64,9 @@ public class MemoryServiceImpl implements MemoryService {
@Override
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
if (chatMemoryFilter.getAgentId() != null) {
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId());
}
if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) {
queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion());
}