mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(improvement)(chat) Vector retrieval supports filtering by modelId collection during query. (#1576)
This commit is contained in:
@@ -8,20 +8,20 @@ import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class MetaEmbeddingService {
|
||||
@@ -32,14 +32,16 @@ public class MetaEmbeddingService {
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
// dataSetIds->modelIds
|
||||
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
String modelId = allModels.stream().findFirst().get().toString();
|
||||
filterCondition.put("modelId", modelId + DictWordType.NATURE_SPILT);
|
||||
if (CollectionUtils.isNotEmpty(allModels)) {
|
||||
Map<String, Object> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", allModels.stream()
|
||||
.map(modelId -> modelId + DictWordType.NATURE_SPILT)
|
||||
.collect(Collectors.toList()));
|
||||
retrieveQuery.setFilterCondition(filterCondition);
|
||||
}
|
||||
|
||||
@@ -48,46 +50,38 @@ public class MetaEmbeddingService {
|
||||
if (CollectionUtils.isEmpty(resultList)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
//filter by modelId
|
||||
if (CollectionUtils.isEmpty(allModels)) {
|
||||
return resultList;
|
||||
}
|
||||
// Filter and process query results.
|
||||
return resultList.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isEmpty(retrievals)) {
|
||||
return retrieveQueryResult;
|
||||
}
|
||||
//filter by modelId
|
||||
retrievals.removeIf(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
if (Objects.isNull(modelId)) {
|
||||
return CollectionUtils.isEmpty(allModels);
|
||||
}
|
||||
return !allModels.contains(modelId);
|
||||
});
|
||||
//add dataSetId
|
||||
retrievals = retrievals.stream().flatMap(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
List<Long> dataSetIdsByModelId = modelIdToDataSetIds.get(modelId);
|
||||
if (!CollectionUtils.isEmpty(dataSetIdsByModelId)) {
|
||||
Set<Retrieval> result = new HashSet<>();
|
||||
for (Long dataSetId : dataSetIdsByModelId) {
|
||||
Retrieval retrievalNew = new Retrieval();
|
||||
BeanUtils.copyProperties(retrieval, retrievalNew);
|
||||
retrievalNew.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
|
||||
result.add(retrievalNew);
|
||||
}
|
||||
return result.stream();
|
||||
}
|
||||
Set<Retrieval> result = new HashSet<>();
|
||||
result.add(retrieval);
|
||||
return result.stream();
|
||||
}).collect(Collectors.toList());
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||
.map(result -> getRetrieveQueryResult(modelIdToDataSetIds, result))
|
||||
.filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static RetrieveQueryResult getRetrieveQueryResult(Map<Long,
|
||||
List<Long>> modelIdToDataSetIds, RetrieveQueryResult result) {
|
||||
List<Retrieval> retrievals = result.getRetrieval();
|
||||
if (CollectionUtils.isEmpty(retrievals)) {
|
||||
return result;
|
||||
}
|
||||
// Process each Retrieval object.
|
||||
List<Retrieval> updatedRetrievals = retrievals.stream()
|
||||
.flatMap(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
|
||||
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return Stream.of(retrieval);
|
||||
}
|
||||
|
||||
return dataSetIds.stream().map(dataSetId -> {
|
||||
Retrieval newRetrieval = new Retrieval();
|
||||
BeanUtils.copyProperties(retrieval, newRetrieval);
|
||||
newRetrieval.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
|
||||
return newRetrieval;
|
||||
});
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
result.setRetrieval(updatedRetrievals);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user