[improvement][chat] The parser interface supports using the dataSetId provided by the frontend as the reference (#1852)

This commit is contained in:
lexluo09
2024-10-28 21:51:11 +08:00
committed by GitHub
parent e046a55567
commit 5d9b1b917e
6 changed files with 63 additions and 43 deletions

View File

@@ -16,6 +16,7 @@ public class ChatParseReq {
private String queryText; private String queryText;
private Integer chatId; private Integer chatId;
private Integer agentId; private Integer agentId;
private Long dataSetId;
private User user; private User user;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private boolean saveAnswer = true; private boolean saveAnswer = true;

View File

@@ -1,56 +1,61 @@
package com.tencent.supersonic.chat.server.agent; package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.memory.MemoryReviewTask; import com.tencent.supersonic.chat.server.memory.MemoryReviewTask;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data; import lombok.Data;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.*; import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Data @Data
public class Agent extends RecordInfo { public class Agent extends RecordInfo {
private static final int ONLINE_STATUS = 1;
private static final int OFFLINE_STATUS = 0;
private static final int ENABLED = 1;
private static final int DISABLED = 0;
private Integer id; private Integer id;
private String name; private String name;
private String description; private String description;
/** 0 offline, 1 online */ /** 0 offline, 1 online */
private Integer status = 1; private Integer status = ONLINE_STATUS;
private List<String> examples; private List<String> examples;
private Integer enableSearch = 1; private Integer enableSearch = ENABLED;
private Integer enableFeedback = 0; private Integer enableFeedback = DISABLED;
private String toolConfig; private String toolConfig;
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP; private Map<String, ChatApp> chatAppConfig = Collections.emptyMap();
private VisualConfig visualConfig; private VisualConfig visualConfig;
public List<String> getTools(AgentToolType type) { public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(toolConfig, Map.class); Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) { if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
return Lists.newArrayList(); return Collections.emptyList();
} }
List<Map> toolList = (List) map.get("tools"); List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
return toolList.stream().filter(tool -> { return toolList.stream()
if (Objects.isNull(type)) { .filter(tool -> type == null || type.name().equals(tool.get("type")))
return true; .map(JSONObject::toJSONString).collect(Collectors.toList());
}
return type.name().equals(tool.get("type"));
}).map(JSONObject::toJSONString).collect(Collectors.toList());
} }
public boolean enableSearch() { public boolean enableSearch() {
return enableSearch == 1; return enableSearch == ENABLED;
} }
public boolean enableFeedback() { public boolean enableFeedback() {
return enableFeedback == 1; return enableFeedback == ENABLED;
} }
public boolean enableMemoryReview() { public boolean enableMemoryReview() {
return chatAppConfig.get(MemoryReviewTask.APP_KEY).isEnable(); ChatApp memoryReviewApp = chatAppConfig.get(MemoryReviewTask.APP_KEY);
return memoryReviewApp != null && memoryReviewApp.isEnable();
} }
public static boolean containsAllModel(Set<Long> detectViewIds) { public static boolean containsAllModel(Set<Long> detectViewIds) {
@@ -60,7 +65,7 @@ public class Agent extends RecordInfo {
public List<DatasetTool> getParserTools(AgentToolType agentToolType) { public List<DatasetTool> getParserTools(AgentToolType agentToolType) {
List<String> tools = this.getTools(agentToolType); List<String> tools = this.getTools(agentToolType);
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList(); return Collections.emptyList();
} }
return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class)) return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class))
.collect(Collectors.toList()); .collect(Collectors.toList());
@@ -75,22 +80,18 @@ public class Agent extends RecordInfo {
} }
public boolean containsAnyTool() { public boolean containsAnyTool() {
Map map = JSONObject.parseObject(toolConfig, Map.class); Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map)) { if (CollectionUtils.isEmpty(map)) {
return false; return false;
} }
List<Map> toolList = (List) map.get("tools"); List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
if (CollectionUtils.isEmpty(toolList)) { return !CollectionUtils.isEmpty(toolList);
return false;
}
return true;
} }
public Set<Long> getDataSetIds() { public Set<Long> getDataSetIds() {
Set<Long> dataSetIds = getDataSetIds(null); Set<Long> dataSetIds = getDataSetIds(null);
if (containsAllModel(dataSetIds)) { if (containsAllModel(dataSetIds)) {
return Sets.newHashSet(); return Collections.emptySet();
} }
return dataSetIds; return dataSetIds;
} }
@@ -98,10 +99,10 @@ public class Agent extends RecordInfo {
public Set<Long> getDataSetIds(AgentToolType agentToolType) { public Set<Long> getDataSetIds(AgentToolType agentToolType) {
List<DatasetTool> commonAgentTools = getParserTools(agentToolType); List<DatasetTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) { if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>(); return Collections.emptySet();
} }
return commonAgentTools.stream().map(DatasetTool::getDataSetIds) return commonAgentTools.stream().map(DatasetTool::getDataSetIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream) .filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds))
.collect(Collectors.toSet()); .flatMap(Collection::stream).collect(Collectors.toSet());
} }
} }

View File

@@ -14,7 +14,6 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;

View File

@@ -1,9 +1,15 @@
package com.tencent.supersonic.chat.server.util; package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.Objects;
import java.util.Set;
public class QueryReqConverter { public class QueryReqConverter {
@@ -12,10 +18,23 @@ public class QueryReqConverter {
BeanMapper.mapper(parseContext.getRequest(), queryNLReq); BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
queryNLReq.setText2SQLType( queryNLReq.setText2SQLType(
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE); parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds()); queryNLReq.setDataSetIds(getDataSetIds(parseContext));
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse()); queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
return queryNLReq; return queryNLReq;
} }
private static Set<Long> getDataSetIds(ParseContext parseContext) {
ChatParseReq chatParseReq = parseContext.getRequest();
Set<Long> dataSetIds = parseContext.getAgent().getDataSetIds();
Long requestDataSetId = chatParseReq.getDataSetId();
if (Objects.nonNull(requestDataSetId)) {
if (CollectionUtils.isEmpty(dataSetIds)) {
return Collections.singleton(requestDataSetId);
}
dataSetIds.removeIf(dataSetId -> !dataSetId.equals(requestDataSetId));
}
return dataSetIds;
}
} }

View File

@@ -82,8 +82,8 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals( Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
"AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
replaceSql); replaceSql);
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
@@ -93,8 +93,8 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals( Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND " + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND "
"歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", + "歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
replaceSql); replaceSql);
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
@@ -105,9 +105,9 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals( Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND " + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND "
"歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) " + + "歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) "
"AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
replaceSql); replaceSql);
Map<String, Map<String, String>> filedNameToValueMap3 = new HashMap<>(); Map<String, Map<String, String>> filedNameToValueMap3 = new HashMap<>();

View File

@@ -20,7 +20,7 @@ import java.util.Set;
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> { public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
@Override @Override
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms, public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
String text = chatQueryContext.getRequest().getQueryText(); String text = chatQueryContext.getRequest().getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null; return null;
@@ -36,7 +36,7 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
} }
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms, public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
throw new RuntimeException("Not implemented"); throw new RuntimeException("Not implemented");
} }