[improvement][chat] The parser interface supports using the dataSetId provided by the frontend as the reference (#1852)

This commit is contained in:
lexluo09
2024-10-28 21:51:11 +08:00
committed by GitHub
parent e046a55567
commit 5d9b1b917e
6 changed files with 63 additions and 43 deletions

View File

@@ -16,6 +16,7 @@ public class ChatParseReq {
private String queryText;
private Integer chatId;
private Integer agentId;
private Long dataSetId;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;

View File

@@ -1,56 +1,61 @@
package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.memory.MemoryReviewTask;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Data
public class Agent extends RecordInfo {
private static final int ONLINE_STATUS = 1;
private static final int OFFLINE_STATUS = 0;
private static final int ENABLED = 1;
private static final int DISABLED = 0;
private Integer id;
private String name;
private String description;
/** 0 offline, 1 online */
private Integer status = 1;
private Integer status = ONLINE_STATUS;
private List<String> examples;
private Integer enableSearch = 1;
private Integer enableFeedback = 0;
private Integer enableSearch = ENABLED;
private Integer enableFeedback = DISABLED;
private String toolConfig;
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
private Map<String, ChatApp> chatAppConfig = Collections.emptyMap();
private VisualConfig visualConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(toolConfig, Map.class);
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
return Lists.newArrayList();
return Collections.emptyList();
}
List<Map> toolList = (List) map.get("tools");
return toolList.stream().filter(tool -> {
if (Objects.isNull(type)) {
return true;
}
return type.name().equals(tool.get("type"));
}).map(JSONObject::toJSONString).collect(Collectors.toList());
List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
return toolList.stream()
.filter(tool -> type == null || type.name().equals(tool.get("type")))
.map(JSONObject::toJSONString).collect(Collectors.toList());
}
public boolean enableSearch() {
return enableSearch == 1;
return enableSearch == ENABLED;
}
public boolean enableFeedback() {
return enableFeedback == 1;
return enableFeedback == ENABLED;
}
public boolean enableMemoryReview() {
return chatAppConfig.get(MemoryReviewTask.APP_KEY).isEnable();
ChatApp memoryReviewApp = chatAppConfig.get(MemoryReviewTask.APP_KEY);
return memoryReviewApp != null && memoryReviewApp.isEnable();
}
public static boolean containsAllModel(Set<Long> detectViewIds) {
@@ -60,7 +65,7 @@ public class Agent extends RecordInfo {
public List<DatasetTool> getParserTools(AgentToolType agentToolType) {
List<String> tools = this.getTools(agentToolType);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
return Collections.emptyList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class))
.collect(Collectors.toList());
@@ -75,22 +80,18 @@ public class Agent extends RecordInfo {
}
public boolean containsAnyTool() {
Map map = JSONObject.parseObject(toolConfig, Map.class);
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map)) {
return false;
}
List<Map> toolList = (List) map.get("tools");
if (CollectionUtils.isEmpty(toolList)) {
return false;
}
return true;
List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
return !CollectionUtils.isEmpty(toolList);
}
public Set<Long> getDataSetIds() {
Set<Long> dataSetIds = getDataSetIds(null);
if (containsAllModel(dataSetIds)) {
return Sets.newHashSet();
return Collections.emptySet();
}
return dataSetIds;
}
@@ -98,10 +99,10 @@ public class Agent extends RecordInfo {
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
List<DatasetTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
return Collections.emptySet();
}
return commonAgentTools.stream().map(DatasetTool::getDataSetIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream)
.collect(Collectors.toSet());
.filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds))
.flatMap(Collection::stream).collect(Collectors.toSet());
}
}

View File

@@ -14,7 +14,6 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;

View File

@@ -1,9 +1,15 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.Objects;
import java.util.Set;
public class QueryReqConverter {
@@ -12,10 +18,23 @@ public class QueryReqConverter {
BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
queryNLReq.setText2SQLType(
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
queryNLReq.setDataSetIds(getDataSetIds(parseContext));
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
return queryNLReq;
}
private static Set<Long> getDataSetIds(ParseContext parseContext) {
ChatParseReq chatParseReq = parseContext.getRequest();
Set<Long> dataSetIds = parseContext.getAgent().getDataSetIds();
Long requestDataSetId = chatParseReq.getDataSetId();
if (Objects.nonNull(requestDataSetId)) {
if (CollectionUtils.isEmpty(dataSetIds)) {
return Collections.singleton(requestDataSetId);
}
dataSetIds.removeIf(dataSetId -> !dataSetId.equals(requestDataSetId));
}
return dataSetIds;
}
}

View File

@@ -82,8 +82,8 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " +
"AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
@@ -93,8 +93,8 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND " +
"歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND "
+ "歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
@@ -105,9 +105,9 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND " +
"歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) " +
"AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND "
+ "歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) "
+ "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
Map<String, Map<String, String>> filedNameToValueMap3 = new HashMap<>();