mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Vector retrieval supports filtering by modelId collection during query. (#1576)
This commit is contained in:
@@ -1,20 +1,19 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.execute;
|
package com.tencent.supersonic.chat.server.processor.execute;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||||
import java.util.Objects;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@@ -23,6 +22,7 @@ import java.util.HashMap;
|
|||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||||
Map<String, String> filterCondition = new HashMap<>();
|
Map<String, Object> filterCondition = new HashMap<>();
|
||||||
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
|
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
|
||||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
if (cachedResult != null) {
|
if (cachedResult != null) {
|
||||||
return cachedResult;
|
return cachedResult;
|
||||||
}
|
}
|
||||||
Map<String, String> filterCondition = new HashMap<>();
|
Map<String, Object> filterCondition = new HashMap<>();
|
||||||
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
|
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
|
||||||
Filter filter = createCombinedFilter(filterCondition);
|
Filter filter = createCombinedFilter(filterCondition);
|
||||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
@@ -115,7 +115,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||||
EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
|
EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
|
||||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
||||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
Map<String, Object> filterCondition = retrieveQuery.getFilterCondition();
|
||||||
return retrieveQuery.getQueryTextsList().stream()
|
return retrieveQuery.getQueryTextsList().stream()
|
||||||
.map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
|
.map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
@@ -124,7 +124,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
private RetrieveQueryResult retrieveSingleQuery(String queryText,
|
private RetrieveQueryResult retrieveSingleQuery(String queryText,
|
||||||
EmbeddingModel embeddingModel,
|
EmbeddingModel embeddingModel,
|
||||||
EmbeddingStore embeddingStore,
|
EmbeddingStore embeddingStore,
|
||||||
Map<String, String> filterCondition,
|
Map<String, Object> filterCondition,
|
||||||
int num) {
|
int num) {
|
||||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||||
Filter filter = createCombinedFilter(filterCondition);
|
Filter filter = createCombinedFilter(filterCondition);
|
||||||
@@ -159,14 +159,27 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
return retrieval;
|
return retrieval;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Filter createCombinedFilter(Map<String, String> map) {
|
public static Filter createCombinedFilter(Map<String, Object> map) {
|
||||||
if (MapUtils.isEmpty(map)) {
|
if (MapUtils.isEmpty(map)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
Filter result = null;
|
Filter result = null;
|
||||||
for (Map.Entry<String, String> entry : map.entrySet()) {
|
for (Map.Entry<String, Object> entry : map.entrySet()) {
|
||||||
IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue());
|
String key = entry.getKey();
|
||||||
result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo);
|
Object value = entry.getValue();
|
||||||
|
Filter orFilter = null;
|
||||||
|
|
||||||
|
if (value instanceof List) {
|
||||||
|
for (String val : (List<String>) value) {
|
||||||
|
IsEqualTo isEqualTo = new IsEqualTo(key, val);
|
||||||
|
orFilter = (orFilter == null) ? isEqualTo : Filter.or(orFilter, isEqualTo);
|
||||||
|
}
|
||||||
|
} else if (value instanceof String) {
|
||||||
|
orFilter = new IsEqualTo(key, value);
|
||||||
|
}
|
||||||
|
if (orFilter != null) {
|
||||||
|
result = (result == null) ? orFilter : Filter.and(result, orFilter);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ public class RetrieveQuery {
|
|||||||
|
|
||||||
private List<String> queryTextsList;
|
private List<String> queryTextsList;
|
||||||
|
|
||||||
private Map<String, String> filterCondition;
|
private Map<String, Object> filterCondition;
|
||||||
|
|
||||||
private List<List<Double>> queryEmbeddings;
|
private List<List<Double>> queryEmbeddings;
|
||||||
|
|
||||||
|
|||||||
@@ -8,20 +8,20 @@ import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
|||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MetaEmbeddingService {
|
public class MetaEmbeddingService {
|
||||||
@@ -32,14 +32,16 @@ public class MetaEmbeddingService {
|
|||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
|
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
|
||||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||||
|
Set<Long> detectDataSetIds) {
|
||||||
// dataSetIds->modelIds
|
// dataSetIds->modelIds
|
||||||
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||||
|
|
||||||
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
|
if (CollectionUtils.isNotEmpty(allModels)) {
|
||||||
Map<String, String> filterCondition = new HashMap<>();
|
Map<String, Object> filterCondition = new HashMap<>();
|
||||||
String modelId = allModels.stream().findFirst().get().toString();
|
filterCondition.put("modelId", allModels.stream()
|
||||||
filterCondition.put("modelId", modelId + DictWordType.NATURE_SPILT);
|
.map(modelId -> modelId + DictWordType.NATURE_SPILT)
|
||||||
|
.collect(Collectors.toList()));
|
||||||
retrieveQuery.setFilterCondition(filterCondition);
|
retrieveQuery.setFilterCondition(filterCondition);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,46 +50,38 @@ public class MetaEmbeddingService {
|
|||||||
if (CollectionUtils.isEmpty(resultList)) {
|
if (CollectionUtils.isEmpty(resultList)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
//filter by modelId
|
// Filter and process query results.
|
||||||
if (CollectionUtils.isEmpty(allModels)) {
|
|
||||||
return resultList;
|
|
||||||
}
|
|
||||||
return resultList.stream()
|
return resultList.stream()
|
||||||
.map(retrieveQueryResult -> {
|
.map(result -> getRetrieveQueryResult(modelIdToDataSetIds, result))
|
||||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
.filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
|
||||||
if (CollectionUtils.isEmpty(retrievals)) {
|
|
||||||
return retrieveQueryResult;
|
|
||||||
}
|
|
||||||
//filter by modelId
|
|
||||||
retrievals.removeIf(retrieval -> {
|
|
||||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
|
||||||
if (Objects.isNull(modelId)) {
|
|
||||||
return CollectionUtils.isEmpty(allModels);
|
|
||||||
}
|
|
||||||
return !allModels.contains(modelId);
|
|
||||||
});
|
|
||||||
//add dataSetId
|
|
||||||
retrievals = retrievals.stream().flatMap(retrieval -> {
|
|
||||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
|
||||||
List<Long> dataSetIdsByModelId = modelIdToDataSetIds.get(modelId);
|
|
||||||
if (!CollectionUtils.isEmpty(dataSetIdsByModelId)) {
|
|
||||||
Set<Retrieval> result = new HashSet<>();
|
|
||||||
for (Long dataSetId : dataSetIdsByModelId) {
|
|
||||||
Retrieval retrievalNew = new Retrieval();
|
|
||||||
BeanUtils.copyProperties(retrieval, retrievalNew);
|
|
||||||
retrievalNew.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
|
|
||||||
result.add(retrievalNew);
|
|
||||||
}
|
|
||||||
return result.stream();
|
|
||||||
}
|
|
||||||
Set<Retrieval> result = new HashSet<>();
|
|
||||||
result.add(retrieval);
|
|
||||||
return result.stream();
|
|
||||||
}).collect(Collectors.toList());
|
|
||||||
retrieveQueryResult.setRetrieval(retrievals);
|
|
||||||
return retrieveQueryResult;
|
|
||||||
})
|
|
||||||
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static RetrieveQueryResult getRetrieveQueryResult(Map<Long,
|
||||||
|
List<Long>> modelIdToDataSetIds, RetrieveQueryResult result) {
|
||||||
|
List<Retrieval> retrievals = result.getRetrieval();
|
||||||
|
if (CollectionUtils.isEmpty(retrievals)) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
// Process each Retrieval object.
|
||||||
|
List<Retrieval> updatedRetrievals = retrievals.stream()
|
||||||
|
.flatMap(retrieval -> {
|
||||||
|
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||||
|
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||||
|
return Stream.of(retrieval);
|
||||||
|
}
|
||||||
|
|
||||||
|
return dataSetIds.stream().map(dataSetId -> {
|
||||||
|
Retrieval newRetrieval = new Retrieval();
|
||||||
|
BeanUtils.copyProperties(retrieval, newRetrieval);
|
||||||
|
newRetrieval.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
|
||||||
|
return newRetrieval;
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
result.setRetrieval(updatedRetrievals);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user