(improvement)(chat) Vector retrieval supports filtering by modelId collection during query. (#1576)

This commit is contained in:
lexluo09
2024-08-16 21:31:07 +08:00
committed by GitHub
parent 6aff51d394
commit 7150f19def
4 changed files with 73 additions and 66 deletions

View File

@@ -1,20 +1,19 @@
package com.tencent.supersonic.chat.server.processor.execute; package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
import java.util.Objects; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.Collections; import java.util.Collections;
@@ -23,6 +22,7 @@ import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -45,7 +45,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
return; return;
} }
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName()); List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, String> filterCondition = new HashMap<>(); Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString()); filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
filterCondition.put("type", SchemaElementType.METRIC.name()); filterCondition.put("type", SchemaElementType.METRIC.name());
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames) RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)

View File

@@ -75,7 +75,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
if (cachedResult != null) { if (cachedResult != null) {
return cachedResult; return cachedResult;
} }
Map<String, String> filterCondition = new HashMap<>(); Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId); filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
Filter filter = createCombinedFilter(filterCondition); Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
@@ -115,7 +115,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName); EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition(); Map<String, Object> filterCondition = retrieveQuery.getFilterCondition();
return retrieveQuery.getQueryTextsList().stream() return retrieveQuery.getQueryTextsList().stream()
.map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num)) .map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
.collect(Collectors.toList()); .collect(Collectors.toList());
@@ -124,7 +124,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
private RetrieveQueryResult retrieveSingleQuery(String queryText, private RetrieveQueryResult retrieveSingleQuery(String queryText,
EmbeddingModel embeddingModel, EmbeddingModel embeddingModel,
EmbeddingStore embeddingStore, EmbeddingStore embeddingStore,
Map<String, String> filterCondition, Map<String, Object> filterCondition,
int num) { int num) {
Embedding embeddedText = embeddingModel.embed(queryText).content(); Embedding embeddedText = embeddingModel.embed(queryText).content();
Filter filter = createCombinedFilter(filterCondition); Filter filter = createCombinedFilter(filterCondition);
@@ -159,14 +159,27 @@ public class EmbeddingServiceImpl implements EmbeddingService {
return retrieval; return retrieval;
} }
private static Filter createCombinedFilter(Map<String, String> map) { public static Filter createCombinedFilter(Map<String, Object> map) {
if (MapUtils.isEmpty(map)) { if (MapUtils.isEmpty(map)) {
return null; return null;
} }
Filter result = null; Filter result = null;
for (Map.Entry<String, String> entry : map.entrySet()) { for (Map.Entry<String, Object> entry : map.entrySet()) {
IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue()); String key = entry.getKey();
result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo); Object value = entry.getValue();
Filter orFilter = null;
if (value instanceof List) {
for (String val : (List<String>) value) {
IsEqualTo isEqualTo = new IsEqualTo(key, val);
orFilter = (orFilter == null) ? isEqualTo : Filter.or(orFilter, isEqualTo);
}
} else if (value instanceof String) {
orFilter = new IsEqualTo(key, value);
}
if (orFilter != null) {
result = (result == null) ? orFilter : Filter.and(result, orFilter);
}
} }
return result; return result;
} }

View File

@@ -12,7 +12,7 @@ public class RetrieveQuery {
private List<String> queryTextsList; private List<String> queryTextsList;
private Map<String, String> filterCondition; private Map<String, Object> filterCondition;
private List<List<Double>> queryEmbeddings; private List<List<Double>> queryEmbeddings;

View File

@@ -8,20 +8,20 @@ import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@Service @Service
@Slf4j @Slf4j
public class MetaEmbeddingService { public class MetaEmbeddingService {
@@ -32,14 +32,16 @@ public class MetaEmbeddingService {
private EmbeddingConfig embeddingConfig; private EmbeddingConfig embeddingConfig;
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num, public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) { Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
// dataSetIds->modelIds // dataSetIds->modelIds
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds); Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) { if (CollectionUtils.isNotEmpty(allModels)) {
Map<String, String> filterCondition = new HashMap<>(); Map<String, Object> filterCondition = new HashMap<>();
String modelId = allModels.stream().findFirst().get().toString(); filterCondition.put("modelId", allModels.stream()
filterCondition.put("modelId", modelId + DictWordType.NATURE_SPILT); .map(modelId -> modelId + DictWordType.NATURE_SPILT)
.collect(Collectors.toList()));
retrieveQuery.setFilterCondition(filterCondition); retrieveQuery.setFilterCondition(filterCondition);
} }
@@ -48,46 +50,38 @@ public class MetaEmbeddingService {
if (CollectionUtils.isEmpty(resultList)) { if (CollectionUtils.isEmpty(resultList)) {
return new ArrayList<>(); return new ArrayList<>();
} }
//filter by modelId // Filter and process query results.
if (CollectionUtils.isEmpty(allModels)) {
return resultList;
}
return resultList.stream() return resultList.stream()
.map(retrieveQueryResult -> { .map(result -> getRetrieveQueryResult(modelIdToDataSetIds, result))
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval(); .filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
if (CollectionUtils.isEmpty(retrievals)) {
return retrieveQueryResult;
}
//filter by modelId
retrievals.removeIf(retrieval -> {
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
if (Objects.isNull(modelId)) {
return CollectionUtils.isEmpty(allModels);
}
return !allModels.contains(modelId);
});
//add dataSetId
retrievals = retrievals.stream().flatMap(retrieval -> {
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
List<Long> dataSetIdsByModelId = modelIdToDataSetIds.get(modelId);
if (!CollectionUtils.isEmpty(dataSetIdsByModelId)) {
Set<Retrieval> result = new HashSet<>();
for (Long dataSetId : dataSetIdsByModelId) {
Retrieval retrievalNew = new Retrieval();
BeanUtils.copyProperties(retrieval, retrievalNew);
retrievalNew.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
result.add(retrievalNew);
}
return result.stream();
}
Set<Retrieval> result = new HashSet<>();
result.add(retrieval);
return result.stream();
}).collect(Collectors.toList());
retrieveQueryResult.setRetrieval(retrievals);
return retrieveQueryResult;
})
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
/**
 * Expands the retrievals of one query result so that there is one retrieval per data set
 * of the retrieval's model.
 *
 * <p>For each retrieval the "modelId" metadata entry is looked up in
 * {@code modelIdToDataSetIds}. When data sets are found, the retrieval is replaced by one
 * copy per data set, each tagged with a "dataSetId" metadata entry (unless the metadata
 * already has one). When none are found, the retrieval is kept as is.
 *
 * @param modelIdToDataSetIds data set ids keyed by model id
 * @param result              the query result to process; its retrieval list is replaced
 * @return the same {@code result} instance, with the expanded retrieval list
 */
private static RetrieveQueryResult getRetrieveQueryResult(
        Map<Long, List<Long>> modelIdToDataSetIds, RetrieveQueryResult result) {
    List<Retrieval> retrievals = result.getRetrieval();
    if (CollectionUtils.isEmpty(retrievals)) {
        return result;
    }
    // Process each Retrieval object.
    List<Retrieval> updatedRetrievals = retrievals.stream()
            .flatMap(retrieval -> {
                Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
                List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
                if (CollectionUtils.isEmpty(dataSetIds)) {
                    return Stream.of(retrieval);
                }
                return dataSetIds.stream().map(dataSetId -> {
                    Retrieval newRetrieval = new Retrieval();
                    BeanUtils.copyProperties(retrieval, newRetrieval);
                    // copyProperties copies the metadata reference, not the map itself.
                    // Give every copy its own map; otherwise all copies (and the source
                    // retrieval) share one map and only the first dataSetId is ever stored.
                    newRetrieval.setMetadata(new HashMap<>(retrieval.getMetadata()));
                    newRetrieval.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
                    return newRetrieval;
                });
            })
            .collect(Collectors.toList());
    result.setRetrieval(updatedRetrievals);
    return result;
}
} }