mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(Chat) QueryFilterMapper obtain viewId from agent (#778)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -65,6 +65,10 @@ public class Agent extends RecordInfo {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public Set<Long> getViewIds() {
|
||||
return getViewIds(null);
|
||||
}
|
||||
|
||||
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@@ -24,29 +26,35 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
Long viewId = queryContext.getViewId();
|
||||
if (viewId == null || viewId <= 0) {
|
||||
Agent agent = queryContext.getAgent();
|
||||
if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) {
|
||||
return;
|
||||
}
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||
if (Agent.containsAllModel(agent.getViewIds())) {
|
||||
return;
|
||||
}
|
||||
Set<Long> viewIds = agent.getViewIds();
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
clearOtherSchemaElementMatch(viewIds, schemaMapInfo);
|
||||
for (Long viewId : viewIds) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||
}
|
||||
addValueSchemaElementMatch(viewId, queryContext, schemaElementMatches);
|
||||
}
|
||||
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
||||
}
|
||||
|
||||
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||
if (!entry.getKey().equals(modelId)) {
|
||||
if (!viewIds.contains(entry.getKey())) {
|
||||
entry.getValue().clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
|
||||
private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext,
|
||||
List<SchemaElementMatch> candidateElementMatches) {
|
||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
@@ -61,7 +69,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.view(queryContext.getViewId())
|
||||
.view(viewId)
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
|
||||
@@ -20,9 +20,9 @@ public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
|
||||
new ContextInheritParser(),
|
||||
new AgentCheckParser(),
|
||||
new TimeRangeParser(),
|
||||
new AggregateTypeParser()
|
||||
new AggregateTypeParser(),
|
||||
new AgentCheckParser()
|
||||
);
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user