mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(improvement)(chat) Vector retrieval supports filtering by modelId collection during query. (#1576)
This commit is contained in:
@@ -75,7 +75,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
if (cachedResult != null) {
|
||||
return cachedResult;
|
||||
}
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
Map<String, Object> filterCondition = new HashMap<>();
|
||||
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
|
||||
Filter filter = createCombinedFilter(filterCondition);
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
@@ -115,7 +115,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
|
||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
Map<String, Object> filterCondition = retrieveQuery.getFilterCondition();
|
||||
return retrieveQuery.getQueryTextsList().stream()
|
||||
.map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
|
||||
.collect(Collectors.toList());
|
||||
@@ -124,7 +124,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
private RetrieveQueryResult retrieveSingleQuery(String queryText,
|
||||
EmbeddingModel embeddingModel,
|
||||
EmbeddingStore embeddingStore,
|
||||
Map<String, String> filterCondition,
|
||||
Map<String, Object> filterCondition,
|
||||
int num) {
|
||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||
Filter filter = createCombinedFilter(filterCondition);
|
||||
@@ -159,14 +159,27 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
return retrieval;
|
||||
}
|
||||
|
||||
private static Filter createCombinedFilter(Map<String, String> map) {
|
||||
public static Filter createCombinedFilter(Map<String, Object> map) {
|
||||
if (MapUtils.isEmpty(map)) {
|
||||
return null;
|
||||
}
|
||||
Filter result = null;
|
||||
for (Map.Entry<String, String> entry : map.entrySet()) {
|
||||
IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue());
|
||||
result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo);
|
||||
for (Map.Entry<String, Object> entry : map.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
Object value = entry.getValue();
|
||||
Filter orFilter = null;
|
||||
|
||||
if (value instanceof List) {
|
||||
for (String val : (List<String>) value) {
|
||||
IsEqualTo isEqualTo = new IsEqualTo(key, val);
|
||||
orFilter = (orFilter == null) ? isEqualTo : Filter.or(orFilter, isEqualTo);
|
||||
}
|
||||
} else if (value instanceof String) {
|
||||
orFilter = new IsEqualTo(key, value);
|
||||
}
|
||||
if (orFilter != null) {
|
||||
result = (result == null) ? orFilter : Filter.and(result, orFilter);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ public class RetrieveQuery {
|
||||
|
||||
private List<String> queryTextsList;
|
||||
|
||||
private Map<String, String> filterCondition;
|
||||
private Map<String, Object> filterCondition;
|
||||
|
||||
private List<List<Double>> queryEmbeddings;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user