diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java index a1c24f55f..7a723c56f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java @@ -65,6 +65,10 @@ public class Agent extends RecordInfo { .collect(Collectors.toList()); } + public Set getViewIds() { + return getViewIds(null); + } + public Set getViewIds(AgentToolType agentToolType) { List commonAgentTools = getParserTools(agentToolType); if (CollectionUtils.isEmpty(commonAgentTools)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java index 8e054891d..20ba6743c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java @@ -1,20 +1,22 @@ package com.tencent.supersonic.chat.core.mapper; import com.google.common.collect.Lists; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder; +import com.tencent.supersonic.chat.core.agent.Agent; import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.api.pojo.SchemaElementType; +import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; @Slf4j @@ -24,29 +26,35 @@ public class QueryFilterMapper implements SchemaMapper { @Override public void map(QueryContext queryContext) { - Long viewId = queryContext.getViewId(); - if (viewId == null || viewId <= 0) { + Agent agent = queryContext.getAgent(); + if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) { return; } - SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); - clearOtherSchemaElementMatch(viewId, schemaMapInfo); - List schemaElementMatches = schemaMapInfo.getMatchedElements(viewId); - if (schemaElementMatches == null) { - schemaElementMatches = Lists.newArrayList(); - schemaMapInfo.setMatchedElements(viewId, schemaElementMatches); + if (Agent.containsAllModel(agent.getViewIds())) { + return; + } + Set viewIds = agent.getViewIds(); + SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); + clearOtherSchemaElementMatch(viewIds, schemaMapInfo); + for (Long viewId : viewIds) { + List schemaElementMatches = schemaMapInfo.getMatchedElements(viewId); + if (schemaElementMatches == null) { + schemaElementMatches = Lists.newArrayList(); + schemaMapInfo.setMatchedElements(viewId, schemaElementMatches); + } + addValueSchemaElementMatch(viewId, queryContext, schemaElementMatches); } - addValueSchemaElementMatch(queryContext, schemaElementMatches); } - private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) { + private void clearOtherSchemaElementMatch(Set viewIds, SchemaMapInfo schemaMapInfo) { for (Map.Entry> entry : schemaMapInfo.getViewElementMatches().entrySet()) { - if (!entry.getKey().equals(modelId)) { + if (!viewIds.contains(entry.getKey())) { entry.getValue().clear(); } } } - private List addValueSchemaElementMatch(QueryContext queryContext, + private List addValueSchemaElementMatch(Long viewId, QueryContext queryContext, List candidateElementMatches) { QueryFilters queryFilters = queryContext.getQueryFilters(); if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) { @@ -61,7 +69,7 @@ public class QueryFilterMapper implements SchemaMapper { .name(String.valueOf(filter.getValue())) .type(SchemaElementType.VALUE) .bizName(filter.getBizName()) - .view(queryContext.getViewId()) + .view(viewId) .build(); SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() .element(element) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java index 916f25c0f..48ce387b3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java @@ -20,9 +20,9 @@ public class RuleSqlParser implements SemanticParser { private static List auxiliaryParsers = Arrays.asList( new ContextInheritParser(), - new AgentCheckParser(), new TimeRangeParser(), - new AggregateTypeParser() + new AggregateTypeParser(), + new AgentCheckParser() ); @Override diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index 1fb4f1692..d30eec7c3 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -77,4 +77,4 @@ logging: inMemoryEmbeddingStore: persistent: - path: /tmp + path: d://