mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
[improvement][chat] The parser interface supports using the dataSetId provided by the frontend as the reference (#1852)
This commit is contained in:
@@ -16,6 +16,7 @@ public class ChatParseReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Integer agentId;
|
||||
private Long dataSetId;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
|
||||
@@ -1,56 +1,61 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.server.memory.MemoryReviewTask;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class Agent extends RecordInfo {
|
||||
|
||||
private static final int ONLINE_STATUS = 1;
|
||||
private static final int OFFLINE_STATUS = 0;
|
||||
private static final int ENABLED = 1;
|
||||
private static final int DISABLED = 0;
|
||||
|
||||
private Integer id;
|
||||
private String name;
|
||||
private String description;
|
||||
/** 0 offline, 1 online */
|
||||
private Integer status = 1;
|
||||
private Integer status = ONLINE_STATUS;
|
||||
private List<String> examples;
|
||||
private Integer enableSearch = 1;
|
||||
private Integer enableFeedback = 0;
|
||||
private Integer enableSearch = ENABLED;
|
||||
private Integer enableFeedback = DISABLED;
|
||||
private String toolConfig;
|
||||
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
|
||||
private Map<String, ChatApp> chatAppConfig = Collections.emptyMap();
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
|
||||
return Lists.newArrayList();
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<Map> toolList = (List) map.get("tools");
|
||||
return toolList.stream().filter(tool -> {
|
||||
if (Objects.isNull(type)) {
|
||||
return true;
|
||||
}
|
||||
return type.name().equals(tool.get("type"));
|
||||
}).map(JSONObject::toJSONString).collect(Collectors.toList());
|
||||
List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
|
||||
return toolList.stream()
|
||||
.filter(tool -> type == null || type.name().equals(tool.get("type")))
|
||||
.map(JSONObject::toJSONString).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public boolean enableSearch() {
|
||||
return enableSearch == 1;
|
||||
return enableSearch == ENABLED;
|
||||
}
|
||||
|
||||
public boolean enableFeedback() {
|
||||
return enableFeedback == 1;
|
||||
return enableFeedback == ENABLED;
|
||||
}
|
||||
|
||||
public boolean enableMemoryReview() {
|
||||
return chatAppConfig.get(MemoryReviewTask.APP_KEY).isEnable();
|
||||
ChatApp memoryReviewApp = chatAppConfig.get(MemoryReviewTask.APP_KEY);
|
||||
return memoryReviewApp != null && memoryReviewApp.isEnable();
|
||||
}
|
||||
|
||||
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||
@@ -60,7 +65,7 @@ public class Agent extends RecordInfo {
|
||||
public List<DatasetTool> getParserTools(AgentToolType agentToolType) {
|
||||
List<String> tools = this.getTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class))
|
||||
.collect(Collectors.toList());
|
||||
@@ -75,22 +80,18 @@ public class Agent extends RecordInfo {
|
||||
}
|
||||
|
||||
public boolean containsAnyTool() {
|
||||
Map map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map)) {
|
||||
return false;
|
||||
}
|
||||
List<Map> toolList = (List) map.get("tools");
|
||||
if (CollectionUtils.isEmpty(toolList)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
|
||||
return !CollectionUtils.isEmpty(toolList);
|
||||
}
|
||||
|
||||
public Set<Long> getDataSetIds() {
|
||||
Set<Long> dataSetIds = getDataSetIds(null);
|
||||
if (containsAllModel(dataSetIds)) {
|
||||
return Sets.newHashSet();
|
||||
return Collections.emptySet();
|
||||
}
|
||||
return dataSetIds;
|
||||
}
|
||||
@@ -98,10 +99,10 @@ public class Agent extends RecordInfo {
|
||||
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
|
||||
List<DatasetTool> commonAgentTools = getParserTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||
return new HashSet<>();
|
||||
return Collections.emptySet();
|
||||
}
|
||||
return commonAgentTools.stream().map(DatasetTool::getDataSetIds)
|
||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
.filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds))
|
||||
.flatMap(Collection::stream).collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
public class QueryReqConverter {
|
||||
|
||||
@@ -12,10 +18,23 @@ public class QueryReqConverter {
|
||||
BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
|
||||
queryNLReq.setText2SQLType(
|
||||
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
|
||||
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
|
||||
queryNLReq.setDataSetIds(getDataSetIds(parseContext));
|
||||
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
||||
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
|
||||
|
||||
return queryNLReq;
|
||||
}
|
||||
|
||||
private static Set<Long> getDataSetIds(ParseContext parseContext) {
|
||||
ChatParseReq chatParseReq = parseContext.getRequest();
|
||||
Set<Long> dataSetIds = parseContext.getAgent().getDataSetIds();
|
||||
Long requestDataSetId = chatParseReq.getDataSetId();
|
||||
|
||||
if (Objects.nonNull(requestDataSetId)) {
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return Collections.singleton(requestDataSetId);
|
||||
}
|
||||
dataSetIds.removeIf(dataSetId -> !dataSetId.equals(requestDataSetId));
|
||||
}
|
||||
return dataSetIds;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,8 +82,8 @@ class SqlReplaceHelperTest {
|
||||
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " +
|
||||
"AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
|
||||
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
|
||||
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
@@ -93,8 +93,8 @@ class SqlReplaceHelperTest {
|
||||
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND " +
|
||||
"歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND "
|
||||
+ "歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
|
||||
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
@@ -105,9 +105,9 @@ class SqlReplaceHelperTest {
|
||||
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND " +
|
||||
"歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) " +
|
||||
"AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND "
|
||||
+ "歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) "
|
||||
+ "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap3 = new HashMap<>();
|
||||
|
||||
@@ -20,7 +20,7 @@ import java.util.Set;
|
||||
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getRequest().getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -36,7 +36,7 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
|
||||
}
|
||||
|
||||
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
throw new RuntimeException("Not implemented");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user