From 5d9b1b917e3bb221a53863ddc0b0a11e75f368f2 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 28 Oct 2024 21:51:11 +0800 Subject: [PATCH] [improvement][chat] The parser interface supports using the dataSetId provided by the frontend as the reference (#1852) --- .../chat/api/pojo/request/ChatParseReq.java | 1 + .../supersonic/chat/server/agent/Agent.java | 63 ++++++++++--------- .../chat/server/rest/ChatQueryController.java | 1 - .../chat/server/util/QueryReqConverter.java | 23 ++++++- .../jsqlparser/SqlReplaceHelperTest.java | 14 ++--- .../chat/mapper/BaseMatchStrategy.java | 4 +- 6 files changed, 63 insertions(+), 43 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java index 8794744b5..6dbdab1fa 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java @@ -16,6 +16,7 @@ public class ChatParseReq { private String queryText; private Integer chatId; private Integer agentId; + private Long dataSetId; private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 98ae8363f..0374f10a3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -1,56 +1,61 @@ package com.tencent.supersonic.chat.server.agent; import com.alibaba.fastjson.JSONObject; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import com.tencent.supersonic.chat.server.memory.MemoryReviewTask; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.RecordInfo; import lombok.Data; import org.springframework.util.CollectionUtils; -import java.util.*; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; @Data public class Agent extends RecordInfo { + private static final int ONLINE_STATUS = 1; + private static final int OFFLINE_STATUS = 0; + private static final int ENABLED = 1; + private static final int DISABLED = 0; + private Integer id; private String name; private String description; /** 0 offline, 1 online */ - private Integer status = 1; + private Integer status = ONLINE_STATUS; private List examples; - private Integer enableSearch = 1; - private Integer enableFeedback = 0; + private Integer enableSearch = ENABLED; + private Integer enableFeedback = DISABLED; private String toolConfig; - private Map chatAppConfig = Collections.EMPTY_MAP; + private Map chatAppConfig = Collections.emptyMap(); private VisualConfig visualConfig; public List getTools(AgentToolType type) { - Map map = JSONObject.parseObject(toolConfig, Map.class); + Map map = JSONObject.parseObject(toolConfig, Map.class); if (CollectionUtils.isEmpty(map) || map.get("tools") == null) { - return Lists.newArrayList(); + return Collections.emptyList(); } - List toolList = (List) map.get("tools"); - return toolList.stream().filter(tool -> { - if (Objects.isNull(type)) { - return true; - } - return type.name().equals(tool.get("type")); - }).map(JSONObject::toJSONString).collect(Collectors.toList()); + List> toolList = (List>) map.get("tools"); + return toolList.stream() + .filter(tool -> type == null || type.name().equals(tool.get("type"))) + .map(JSONObject::toJSONString).collect(Collectors.toList()); } public boolean enableSearch() { - return enableSearch == 1; + return enableSearch == ENABLED; } public boolean enableFeedback() { - return enableFeedback == 1; + return enableFeedback == ENABLED; } public boolean enableMemoryReview() { - return chatAppConfig.get(MemoryReviewTask.APP_KEY).isEnable(); + ChatApp memoryReviewApp = chatAppConfig.get(MemoryReviewTask.APP_KEY); + return memoryReviewApp != null && memoryReviewApp.isEnable(); } public static boolean containsAllModel(Set detectViewIds) { @@ -60,7 +65,7 @@ public class Agent extends RecordInfo { public List getParserTools(AgentToolType agentToolType) { List tools = this.getTools(agentToolType); if (CollectionUtils.isEmpty(tools)) { - return Lists.newArrayList(); + return Collections.emptyList(); } return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class)) .collect(Collectors.toList()); @@ -75,22 +80,18 @@ public class Agent extends RecordInfo { } public boolean containsAnyTool() { - Map map = JSONObject.parseObject(toolConfig, Map.class); + Map map = JSONObject.parseObject(toolConfig, Map.class); if (CollectionUtils.isEmpty(map)) { return false; } - List toolList = (List) map.get("tools"); - if (CollectionUtils.isEmpty(toolList)) { - return false; - } - - return true; + List> toolList = (List>) map.get("tools"); + return !CollectionUtils.isEmpty(toolList); } public Set getDataSetIds() { Set dataSetIds = getDataSetIds(null); if (containsAllModel(dataSetIds)) { - return Sets.newHashSet(); + return Collections.emptySet(); } return dataSetIds; } @@ -98,10 +99,10 @@ public class Agent extends RecordInfo { public Set getDataSetIds(AgentToolType agentToolType) { List commonAgentTools = getParserTools(agentToolType); if (CollectionUtils.isEmpty(commonAgentTools)) { - return new HashSet<>(); + return Collections.emptySet(); } return commonAgentTools.stream().map(DatasetTool::getDataSetIds) - .filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream) - .collect(Collectors.toSet()); + .filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds)) + .flatMap(Collection::stream).collect(Collectors.toSet()); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java index e82352498..d3c3f2010 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java @@ -14,7 +14,6 @@ import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 4ef79a766..569f84f71 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -1,9 +1,15 @@ package com.tencent.supersonic.chat.server.util; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; +import org.springframework.util.CollectionUtils; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; public class QueryReqConverter { @@ -12,10 +18,23 @@ public class QueryReqConverter { BeanMapper.mapper(parseContext.getRequest(), queryNLReq); queryNLReq.setText2SQLType( parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE); - queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds()); + queryNLReq.setDataSetIds(getDataSetIds(parseContext)); queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse()); - return queryNLReq; } + + private static Set getDataSetIds(ParseContext parseContext) { + ChatParseReq chatParseReq = parseContext.getRequest(); + Set dataSetIds = parseContext.getAgent().getDataSetIds(); + Long requestDataSetId = chatParseReq.getDataSetId(); + + if (Objects.nonNull(requestDataSetId)) { + if (CollectionUtils.isEmpty(dataSetIds)) { + return Collections.singleton(requestDataSetId); + } + dataSetIds.removeIf(dataSetId -> !dataSetId.equals(requestDataSetId)); + } + return dataSetIds; + } } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java index f4f0c9220..66d51a6d5 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java @@ -82,8 +82,8 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " + - "AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " + + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " @@ -93,8 +93,8 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND " + - "歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND " + + "歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " @@ -105,9 +105,9 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND " + - "歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) " + - "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND " + + "歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) " + + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); Map> filedNameToValueMap3 = new HashMap<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index 69150a879..a99879e3c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -20,7 +20,7 @@ import java.util.Set; public abstract class BaseMatchStrategy implements MatchStrategy { @Override public Map> match(ChatQueryContext chatQueryContext, List terms, - Set detectDataSetIds) { + Set detectDataSetIds) { String text = chatQueryContext.getRequest().getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; @@ -36,7 +36,7 @@ public abstract class BaseMatchStrategy implements MatchStr } public List detect(ChatQueryContext chatQueryContext, List terms, - Set detectDataSetIds) { + Set detectDataSetIds) { throw new RuntimeException("Not implemented"); }