From 7150f19defeab9303c58cae9a7b9a2cad207b0b8 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 16 Aug 2024 21:31:07 +0800 Subject: [PATCH] (improvement)(chat) Vector retrieval supports filtering by modelId collection during query. (#1576) --- .../execute/MetricRecommendProcessor.java | 12 +-- .../service/impl/EmbeddingServiceImpl.java | 27 +++-- .../store/embedding/RetrieveQuery.java | 2 +- .../chat/knowledge/MetaEmbeddingService.java | 98 +++++++++---------- 4 files changed, 73 insertions(+), 66 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java index 33d964ed7..3bc9092b7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java @@ -1,20 +1,19 @@ package com.tencent.supersonic.chat.server.processor.execute; import com.alibaba.fastjson.JSONObject; +import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.util.ContextUtils; -import dev.langchain4j.store.embedding.Retrieval; -import dev.langchain4j.store.embedding.RetrieveQuery; -import dev.langchain4j.store.embedding.RetrieveQueryResult; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; -import java.util.Objects; +import dev.langchain4j.store.embedding.Retrieval; +import dev.langchain4j.store.embedding.RetrieveQuery; +import dev.langchain4j.store.embedding.RetrieveQueryResult; import org.springframework.util.CollectionUtils; import java.util.Collections; @@ -23,6 +22,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -45,7 +45,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor { return; } List metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName()); - Map filterCondition = new HashMap<>(); + Map filterCondition = new HashMap<>(); filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString()); filterCondition.put("type", SchemaElementType.METRIC.name()); RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames) diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 4752c180d..b75efd5ed 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -75,7 +75,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { if (cachedResult != null) { return cachedResult; } - Map filterCondition = new HashMap<>(); + Map filterCondition = new HashMap<>(); filterCondition.put(TextSegmentConvert.QUERY_ID, queryId); Filter filter = createCombinedFilter(filterCondition); EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() @@ -115,7 +115,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { public List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); - Map filterCondition = retrieveQuery.getFilterCondition(); + Map filterCondition = retrieveQuery.getFilterCondition(); return retrieveQuery.getQueryTextsList().stream() .map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num)) .collect(Collectors.toList()); @@ -124,7 +124,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { private RetrieveQueryResult retrieveSingleQuery(String queryText, EmbeddingModel embeddingModel, EmbeddingStore embeddingStore, - Map filterCondition, + Map filterCondition, int num) { Embedding embeddedText = embeddingModel.embed(queryText).content(); Filter filter = createCombinedFilter(filterCondition); @@ -159,14 +159,27 @@ public class EmbeddingServiceImpl implements EmbeddingService { return retrieval; } - private static Filter createCombinedFilter(Map map) { + public static Filter createCombinedFilter(Map map) { if (MapUtils.isEmpty(map)) { return null; } Filter result = null; - for (Map.Entry entry : map.entrySet()) { - IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue()); - result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo); + for (Map.Entry entry : map.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + Filter orFilter = null; + + if (value instanceof List) { + for (String val : (List) value) { + IsEqualTo isEqualTo = new IsEqualTo(key, val); + orFilter = (orFilter == null) ? isEqualTo : Filter.or(orFilter, isEqualTo); + } + } else if (value instanceof String) { + orFilter = new IsEqualTo(key, value); + } + if (orFilter != null) { + result = (result == null) ? orFilter : Filter.and(result, orFilter); + } } return result; } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/RetrieveQuery.java b/common/src/main/java/dev/langchain4j/store/embedding/RetrieveQuery.java index a46bb6fd6..6ba73cfc5 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/RetrieveQuery.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/RetrieveQuery.java @@ -12,7 +12,7 @@ public class RetrieveQuery { private List queryTextsList; - private Map filterCondition; + private Map filterCondition; private List> queryEmbeddings; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java index 22ea68a57..9933998f2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java @@ -8,20 +8,20 @@ import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + @Service @Slf4j public class MetaEmbeddingService { @@ -32,14 +32,16 @@ public class MetaEmbeddingService { private EmbeddingConfig embeddingConfig; public List retrieveQuery(RetrieveQuery retrieveQuery, int num, - Map> modelIdToDataSetIds, Set detectDataSetIds) { + Map> modelIdToDataSetIds, + Set detectDataSetIds) { // dataSetIds->modelIds Set allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds); - if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) { - Map filterCondition = new HashMap<>(); - String modelId = allModels.stream().findFirst().get().toString(); - filterCondition.put("modelId", modelId + DictWordType.NATURE_SPILT); + if (CollectionUtils.isNotEmpty(allModels)) { + Map filterCondition = new HashMap<>(); + filterCondition.put("modelId", allModels.stream() + .map(modelId -> modelId + DictWordType.NATURE_SPILT) + .collect(Collectors.toList())); retrieveQuery.setFilterCondition(filterCondition); } @@ -48,46 +50,38 @@ public class MetaEmbeddingService { if (CollectionUtils.isEmpty(resultList)) { return new ArrayList<>(); } - //filter by modelId - if (CollectionUtils.isEmpty(allModels)) { - return resultList; - } + // Filter and process query results. return resultList.stream() - .map(retrieveQueryResult -> { - List retrievals = retrieveQueryResult.getRetrieval(); - if (CollectionUtils.isEmpty(retrievals)) { - return retrieveQueryResult; - } - //filter by modelId - retrievals.removeIf(retrieval -> { - Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId")); - if (Objects.isNull(modelId)) { - return CollectionUtils.isEmpty(allModels); - } - return !allModels.contains(modelId); - }); - //add dataSetId - retrievals = retrievals.stream().flatMap(retrieval -> { - Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId")); - List dataSetIdsByModelId = modelIdToDataSetIds.get(modelId); - if (!CollectionUtils.isEmpty(dataSetIdsByModelId)) { - Set result = new HashSet<>(); - for (Long dataSetId : dataSetIdsByModelId) { - Retrieval retrievalNew = new Retrieval(); - BeanUtils.copyProperties(retrieval, retrievalNew); - retrievalNew.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE); - result.add(retrievalNew); - } - return result.stream(); - } - Set result = new HashSet<>(); - result.add(retrieval); - return result.stream(); - }).collect(Collectors.toList()); - retrieveQueryResult.setRetrieval(retrievals); - return retrieveQueryResult; - }) - .filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval())) + .map(result -> getRetrieveQueryResult(modelIdToDataSetIds, result)) + .filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval())) .collect(Collectors.toList()); } + + private static RetrieveQueryResult getRetrieveQueryResult(Map> modelIdToDataSetIds, RetrieveQueryResult result) { + List retrievals = result.getRetrieval(); + if (CollectionUtils.isEmpty(retrievals)) { + return result; + } + // Process each Retrieval object. + List updatedRetrievals = retrievals.stream() + .flatMap(retrieval -> { + Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId")); + List dataSetIds = modelIdToDataSetIds.get(modelId); + + if (CollectionUtils.isEmpty(dataSetIds)) { + return Stream.of(retrieval); + } + + return dataSetIds.stream().map(dataSetId -> { + Retrieval newRetrieval = new Retrieval(); + BeanUtils.copyProperties(retrieval, newRetrieval); + newRetrieval.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE); + return newRetrieval; + }); + }) + .collect(Collectors.toList()); + result.setRetrieval(updatedRetrievals); + return result; + } } \ No newline at end of file