mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(improvement)(Chat) QueryFilterMapper obtain viewId from agent (#778)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -65,6 +65,10 @@ public class Agent extends RecordInfo {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Set<Long> getViewIds() {
|
||||||
|
return getViewIds(null);
|
||||||
|
}
|
||||||
|
|
||||||
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||||
|
|||||||
@@ -1,20 +1,22 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -24,29 +26,35 @@ public class QueryFilterMapper implements SchemaMapper {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(QueryContext queryContext) {
|
public void map(QueryContext queryContext) {
|
||||||
Long viewId = queryContext.getViewId();
|
Agent agent = queryContext.getAgent();
|
||||||
if (viewId == null || viewId <= 0) {
|
if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (Agent.containsAllModel(agent.getViewIds())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Set<Long> viewIds = agent.getViewIds();
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
|
clearOtherSchemaElementMatch(viewIds, schemaMapInfo);
|
||||||
|
for (Long viewId : viewIds) {
|
||||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||||
}
|
}
|
||||||
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
addValueSchemaElementMatch(viewId, queryContext, schemaElementMatches);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||||
if (!entry.getKey().equals(modelId)) {
|
if (!viewIds.contains(entry.getKey())) {
|
||||||
entry.getValue().clear();
|
entry.getValue().clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
|
private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext,
|
||||||
List<SchemaElementMatch> candidateElementMatches) {
|
List<SchemaElementMatch> candidateElementMatches) {
|
||||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
@@ -61,7 +69,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
|||||||
.name(String.valueOf(filter.getValue()))
|
.name(String.valueOf(filter.getValue()))
|
||||||
.type(SchemaElementType.VALUE)
|
.type(SchemaElementType.VALUE)
|
||||||
.bizName(filter.getBizName())
|
.bizName(filter.getBizName())
|
||||||
.view(queryContext.getViewId())
|
.view(viewId)
|
||||||
.build();
|
.build();
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
|
|
||||||
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
|
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
|
||||||
new ContextInheritParser(),
|
new ContextInheritParser(),
|
||||||
new AgentCheckParser(),
|
|
||||||
new TimeRangeParser(),
|
new TimeRangeParser(),
|
||||||
new AggregateTypeParser()
|
new AggregateTypeParser(),
|
||||||
|
new AgentCheckParser()
|
||||||
);
|
);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -77,4 +77,4 @@ logging:
|
|||||||
|
|
||||||
inMemoryEmbeddingStore:
|
inMemoryEmbeddingStore:
|
||||||
persistent:
|
persistent:
|
||||||
path: /tmp
|
path: d://
|
||||||
|
|||||||
Reference in New Issue
Block a user