diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java index 00ec855b4..4bc057416 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryFilter.java @@ -15,6 +15,8 @@ import java.util.List; @NoArgsConstructor public class ChatMemoryFilter { + private Integer agentId; + private String question; private List questions; diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java new file mode 100644 index 000000000..79399f2da --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatMemoryUpdateReq.java @@ -0,0 +1,27 @@ +package com.tencent.supersonic.chat.api.pojo.request; + + +import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; +import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import com.tencent.supersonic.common.pojo.RecordInfo; +import lombok.Data; + +import javax.validation.constraints.NotNull; + +@Data +public class ChatMemoryUpdateReq extends RecordInfo { + + @NotNull(message = "id不可为空") + private Long id; + + private String dbSchema; + + private String s2sql; + + private MemoryStatus status; + + private MemoryReviewResult humanReviewRet; + + private String humanReviewCmt; + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java index 963538d49..fea7ee38f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java @@ -6,10 +6,8 @@ import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; -import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; -import lombok.NoArgsConstructor; import lombok.ToString; import java.util.Date; @@ -17,8 +15,6 @@ import java.util.Date; @Data @Builder @ToString -@AllArgsConstructor -@NoArgsConstructor @TableName("s2_chat_memory") public class ChatMemoryDO { @TableId(type = IdType.AUTO) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ChatPlugin.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ChatPlugin.java index 42b0a3938..b54f49a30 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ChatPlugin.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ChatPlugin.java @@ -50,7 +50,7 @@ public class ChatPlugin extends RecordInfo { return Lists.newArrayList(); } - public boolean isContainsAllModel() { + public boolean isContainsAllDataSet() { return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java index d14904200..ad66e8fb7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java @@ -193,26 +193,26 @@ public class PluginManager { public static Pair> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) { SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo(); - Set pluginMatchedModel = getPluginMatchedModel(plugin, chatParseContext); - if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) { + Set pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext); + if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) { return Pair.of(false, Sets.newHashSet()); } List paramOptions = getSemanticOption(plugin); if (CollectionUtils.isEmpty(paramOptions)) { - return Pair.of(true, pluginMatchedModel); + return Pair.of(true, pluginMatchedDataSet); } - Set matchedModel = Sets.newHashSet(); + Set matchedDataSet = Sets.newHashSet(); Map> paramOptionMap = paramOptions.stream() - .collect(Collectors.groupingBy(ParamOption::getModelId)); - for (Long modelId : paramOptionMap.keySet()) { - List params = paramOptionMap.get(modelId); + .collect(Collectors.groupingBy(ParamOption::getDataSetId)); + for (Long dataSetId : paramOptionMap.keySet()) { + List params = paramOptionMap.get(dataSetId); if (CollectionUtils.isEmpty(params)) { - matchedModel.add(modelId); + matchedDataSet.add(dataSetId); continue; } boolean matched = true; for (ParamOption paramOption : params) { - Set elementIdSet = getSchemaElementMatch(modelId, schemaMapInfo); + Set elementIdSet = getSchemaElementMatch(dataSetId, schemaMapInfo); if (CollectionUtils.isEmpty(elementIdSet)) { matched = false; break; @@ -223,13 +223,13 @@ public class PluginManager { } } if (matched) { - matchedModel.add(modelId); + matchedDataSet.add(dataSetId); } } - if (CollectionUtils.isEmpty(matchedModel)) { + if (CollectionUtils.isEmpty(matchedDataSet)) { return Pair.of(false, Sets.newHashSet()); } - return Pair.of(true, matchedModel); + return Pair.of(true, matchedDataSet); } private static Set getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) { @@ -251,7 +251,7 @@ public class PluginManager { return null; } List paramOptions = webBase.getParamOptions(); - if (org.springframework.util.CollectionUtils.isEmpty(paramOptions)) { + if (CollectionUtils.isEmpty(paramOptions)) { return Lists.newArrayList(); } return paramOptions.stream() @@ -259,19 +259,19 @@ public class PluginManager { .collect(Collectors.toList()); } - private static Set getPluginMatchedModel(ChatPlugin plugin, ChatParseContext chatParseContext) { + private static Set getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) { Set matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos(); - if (plugin.isContainsAllModel()) { + if (plugin.isContainsAllDataSet()) { return Sets.newHashSet(plugin.getDefaultMode()); } - List modelIds = plugin.getDataSetList(); - Set pluginMatchedModel = Sets.newHashSet(); - for (Long modelId : modelIds) { - if (matchedDataSets.contains(modelId)) { - pluginMatchedModel.add(modelId); + List dataSetList = plugin.getDataSetList(); + Set pluginMatchedDataSet = Sets.newHashSet(); + for (Long dataSetId : dataSetList) { + if (matchedDataSets.contains(dataSetId)) { + pluginMatchedDataSet.add(dataSetId); } } - return pluginMatchedModel; + return pluginMatchedDataSet; } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java index d85f3331b..a695d0120 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java @@ -15,7 +15,7 @@ public class ParamOption { private String keyAlias; - private Long modelId; + private Long dataSetId; private Long elementId; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java index dd713e287..6ff6321a0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java @@ -77,8 +77,8 @@ public abstract class PluginSemanticQuery { List paramOptions = Lists.newArrayList(); if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) { for (ParamOption paramOption : webPage.getParamOptions()) { - if (paramOption.getModelId() != null - && !parseInfo.getDataSetId().equals(paramOption.getModelId())) { + if (paramOption.getDataSetId() != null + && !parseInfo.getDataSetId().equals(paramOption.getDataSetId())) { continue; } paramOptions.add(paramOption); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index dacbe9d40..cae4347ac 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -47,7 +47,7 @@ public abstract class PluginRecognizer { PluginRecallResult pluginRecallResult) { ChatPlugin plugin = pluginRecallResult.getPlugin(); Set dataSetIds = pluginRecallResult.getDataSetIds(); - if (plugin.isContainsAllModel()) { + if (plugin.isContainsAllDataSet()) { dataSetIds = Sets.newHashSet(-1L); } for (Long dataSetId : dataSetIds) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java index 55f20a310..58f02a757 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.rest; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.service.MemoryService; @@ -14,7 +15,6 @@ import org.springframework.web.bind.annotation.RestController; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.util.Date; @RestController @RequestMapping({"/api/chat/memory"}) @@ -23,27 +23,13 @@ public class MemoryController { @Autowired private MemoryService memoryService; - @PostMapping("/createMemory") - public Boolean createMemory(@RequestBody ChatMemoryDO memory, - HttpServletRequest request, - HttpServletResponse response) { - User user = UserHolder.findUser(request, response); - memory.setCreatedBy(user.getName()); - memory.setUpdatedBy(user.getName()); - memory.setCreatedAt(new Date()); - memory.setUpdatedAt(new Date()); - memoryService.createMemory(memory); - return true; - } - @PostMapping("/updateMemory") - public void updateMemory(ChatMemoryDO memory, + public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); - memory.setUpdatedBy(user.getName()); - memory.setUpdatedAt(new Date()); - memoryService.updateMemory(memory); + memoryService.updateMemory(chatMemoryUpdateReq, user); + return true; } @RequestMapping("/pageMemories") diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java index 3e445b01c..8ff65bb73 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java @@ -1,7 +1,9 @@ package com.tencent.supersonic.chat.server.service; import com.github.pagehelper.PageInfo; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; @@ -10,6 +12,8 @@ import java.util.List; public interface MemoryService { void createMemory(ChatMemoryDO memory); + void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user); + void updateMemory(ChatMemoryDO memory); PageInfo pageMemories(PageMemoryReq pageMemoryReq); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index e9fe0b3e6..62db679e5 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -96,7 +96,11 @@ public class AgentServiceImpl extends ServiceImpl if (memoriesExisted.contains(example)) { continue; } - chatService.parseAndExecute(-1, agent.getId(), example); + try { + chatService.parseAndExecute(-1, agent.getId(), example); + } catch (Exception e) { + log.warn("agent:{} example execute failed:{}", agent.getName(), example); + } } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index 477fe3018..49fb9aa4d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -3,14 +3,17 @@ package com.tencent.supersonic.chat.server.service.impl; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageInfo; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; +import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.service.ExemplarService; +import com.tencent.supersonic.common.util.BeanMapper; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -33,12 +36,21 @@ public class MemoryServiceImpl implements MemoryService { } @Override - public void updateMemory(ChatMemoryDO memory) { - if (MemoryStatus.ENABLED.equals(memory.getStatus())) { - enableMemory(memory); - } else if (MemoryStatus.DISABLED.equals(memory.getStatus())) { - disableMemory(memory); + public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) { + chatMemoryUpdateReq.updatedBy(user.getName()); + ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId()); + boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus()); + BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO); + if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) { + enableMemory(chatMemoryDO); + } else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) { + disableMemory(chatMemoryDO); } + updateMemory(chatMemoryDO); + } + + @Override + public void updateMemory(ChatMemoryDO memory) { chatMemoryRepository.updateMemory(memory); } @@ -52,6 +64,9 @@ public class MemoryServiceImpl implements MemoryService { @Override public List getMemories(ChatMemoryFilter chatMemoryFilter) { QueryWrapper queryWrapper = new QueryWrapper<>(); + if (chatMemoryFilter.getAgentId() != null) { + queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId()); + } if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) { queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion()); }