mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) Opt update of ChatMemory (#1286)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -15,6 +15,8 @@ import java.util.List;
|
||||
@NoArgsConstructor
|
||||
public class ChatMemoryFilter {
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
private String question;
|
||||
|
||||
private List<String> questions;
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
@Data
|
||||
public class ChatMemoryUpdateReq extends RecordInfo {
|
||||
|
||||
@NotNull(message = "id不可为空")
|
||||
private Long id;
|
||||
|
||||
private String dbSchema;
|
||||
|
||||
private String s2sql;
|
||||
|
||||
private MemoryStatus status;
|
||||
|
||||
private MemoryReviewResult humanReviewRet;
|
||||
|
||||
private String humanReviewCmt;
|
||||
|
||||
}
|
||||
@@ -6,10 +6,8 @@ import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Date;
|
||||
@@ -17,8 +15,6 @@ import java.util.Date;
|
||||
@Data
|
||||
@Builder
|
||||
@ToString
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@TableName("s2_chat_memory")
|
||||
public class ChatMemoryDO {
|
||||
@TableId(type = IdType.AUTO)
|
||||
|
||||
@@ -50,7 +50,7 @@ public class ChatPlugin extends RecordInfo {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
public boolean isContainsAllDataSet() {
|
||||
return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L);
|
||||
}
|
||||
|
||||
|
||||
@@ -193,26 +193,26 @@ public class PluginManager {
|
||||
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, chatParseContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
List<ParamOption> paramOptions = getSemanticOption(plugin);
|
||||
if (CollectionUtils.isEmpty(paramOptions)) {
|
||||
return Pair.of(true, pluginMatchedModel);
|
||||
return Pair.of(true, pluginMatchedDataSet);
|
||||
}
|
||||
Set<Long> matchedModel = Sets.newHashSet();
|
||||
Set<Long> matchedDataSet = Sets.newHashSet();
|
||||
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
|
||||
.collect(Collectors.groupingBy(ParamOption::getModelId));
|
||||
for (Long modelId : paramOptionMap.keySet()) {
|
||||
List<ParamOption> params = paramOptionMap.get(modelId);
|
||||
.collect(Collectors.groupingBy(ParamOption::getDataSetId));
|
||||
for (Long dataSetId : paramOptionMap.keySet()) {
|
||||
List<ParamOption> params = paramOptionMap.get(dataSetId);
|
||||
if (CollectionUtils.isEmpty(params)) {
|
||||
matchedModel.add(modelId);
|
||||
matchedDataSet.add(dataSetId);
|
||||
continue;
|
||||
}
|
||||
boolean matched = true;
|
||||
for (ParamOption paramOption : params) {
|
||||
Set<Long> elementIdSet = getSchemaElementMatch(modelId, schemaMapInfo);
|
||||
Set<Long> elementIdSet = getSchemaElementMatch(dataSetId, schemaMapInfo);
|
||||
if (CollectionUtils.isEmpty(elementIdSet)) {
|
||||
matched = false;
|
||||
break;
|
||||
@@ -223,13 +223,13 @@ public class PluginManager {
|
||||
}
|
||||
}
|
||||
if (matched) {
|
||||
matchedModel.add(modelId);
|
||||
matchedDataSet.add(dataSetId);
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(matchedModel)) {
|
||||
if (CollectionUtils.isEmpty(matchedDataSet)) {
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
return Pair.of(true, matchedModel);
|
||||
return Pair.of(true, matchedDataSet);
|
||||
}
|
||||
|
||||
private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||
@@ -251,7 +251,7 @@ public class PluginManager {
|
||||
return null;
|
||||
}
|
||||
List<ParamOption> paramOptions = webBase.getParamOptions();
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(paramOptions)) {
|
||||
if (CollectionUtils.isEmpty(paramOptions)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return paramOptions.stream()
|
||||
@@ -259,19 +259,19 @@ public class PluginManager {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedModel(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
List<Long> modelIds = plugin.getDataSetList();
|
||||
Set<Long> pluginMatchedModel = Sets.newHashSet();
|
||||
for (Long modelId : modelIds) {
|
||||
if (matchedDataSets.contains(modelId)) {
|
||||
pluginMatchedModel.add(modelId);
|
||||
List<Long> dataSetList = plugin.getDataSetList();
|
||||
Set<Long> pluginMatchedDataSet = Sets.newHashSet();
|
||||
for (Long dataSetId : dataSetList) {
|
||||
if (matchedDataSets.contains(dataSetId)) {
|
||||
pluginMatchedDataSet.add(dataSetId);
|
||||
}
|
||||
}
|
||||
return pluginMatchedModel;
|
||||
return pluginMatchedDataSet;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ public class ParamOption {
|
||||
|
||||
private String keyAlias;
|
||||
|
||||
private Long modelId;
|
||||
private Long dataSetId;
|
||||
|
||||
private Long elementId;
|
||||
|
||||
|
||||
@@ -77,8 +77,8 @@ public abstract class PluginSemanticQuery {
|
||||
List<ParamOption> paramOptions = Lists.newArrayList();
|
||||
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
|
||||
for (ParamOption paramOption : webPage.getParamOptions()) {
|
||||
if (paramOption.getModelId() != null
|
||||
&& !parseInfo.getDataSetId().equals(paramOption.getModelId())) {
|
||||
if (paramOption.getDataSetId() != null
|
||||
&& !parseInfo.getDataSetId().equals(paramOption.getDataSetId())) {
|
||||
continue;
|
||||
}
|
||||
paramOptions.add(paramOption);
|
||||
|
||||
@@ -47,7 +47,7 @@ public abstract class PluginRecognizer {
|
||||
PluginRecallResult pluginRecallResult) {
|
||||
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
dataSetIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.rest;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
@@ -14,7 +15,6 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.Date;
|
||||
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/memory"})
|
||||
@@ -23,27 +23,13 @@ public class MemoryController {
|
||||
@Autowired
|
||||
private MemoryService memoryService;
|
||||
|
||||
@PostMapping("/createMemory")
|
||||
public Boolean createMemory(@RequestBody ChatMemoryDO memory,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memory.setCreatedBy(user.getName());
|
||||
memory.setUpdatedBy(user.getName());
|
||||
memory.setCreatedAt(new Date());
|
||||
memory.setUpdatedAt(new Date());
|
||||
memoryService.createMemory(memory);
|
||||
return true;
|
||||
}
|
||||
|
||||
@PostMapping("/updateMemory")
|
||||
public void updateMemory(ChatMemoryDO memory,
|
||||
public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memory.setUpdatedBy(user.getName());
|
||||
memory.setUpdatedAt(new Date());
|
||||
memoryService.updateMemory(memory);
|
||||
memoryService.updateMemory(chatMemoryUpdateReq, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
@RequestMapping("/pageMemories")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
|
||||
@@ -10,6 +12,8 @@ import java.util.List;
|
||||
public interface MemoryService {
|
||||
void createMemory(ChatMemoryDO memory);
|
||||
|
||||
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
|
||||
|
||||
void updateMemory(ChatMemoryDO memory);
|
||||
|
||||
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
|
||||
|
||||
@@ -96,7 +96,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
if (memoriesExisted.contains(example)) {
|
||||
continue;
|
||||
}
|
||||
chatService.parseAndExecute(-1, agent.getId(), example);
|
||||
try {
|
||||
chatService.parseAndExecute(-1, agent.getId(), example);
|
||||
} catch (Exception e) {
|
||||
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,14 +3,17 @@ package com.tencent.supersonic.chat.server.service.impl;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -33,12 +36,21 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryDO memory) {
|
||||
if (MemoryStatus.ENABLED.equals(memory.getStatus())) {
|
||||
enableMemory(memory);
|
||||
} else if (MemoryStatus.DISABLED.equals(memory.getStatus())) {
|
||||
disableMemory(memory);
|
||||
public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
|
||||
chatMemoryUpdateReq.updatedBy(user.getName());
|
||||
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
|
||||
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
|
||||
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
|
||||
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
|
||||
enableMemory(chatMemoryDO);
|
||||
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
|
||||
disableMemory(chatMemoryDO);
|
||||
}
|
||||
updateMemory(chatMemoryDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryDO memory) {
|
||||
chatMemoryRepository.updateMemory(memory);
|
||||
}
|
||||
|
||||
@@ -52,6 +64,9 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
if (chatMemoryFilter.getAgentId() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId());
|
||||
}
|
||||
if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) {
|
||||
queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user