(improvement)(chat) Vector retrieval supports filtering by modelId collection during query. (#1576)

This commit is contained in:
lexluo09
2024-08-16 21:31:07 +08:00
committed by GitHub
parent 6aff51d394
commit 7150f19def
4 changed files with 73 additions and 66 deletions

View File

@@ -75,7 +75,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
if (cachedResult != null) {
return cachedResult;
}
Map<String, String> filterCondition = new HashMap<>();
Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
@@ -115,7 +115,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
Map<String, Object> filterCondition = retrieveQuery.getFilterCondition();
return retrieveQuery.getQueryTextsList().stream()
.map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
.collect(Collectors.toList());
@@ -124,7 +124,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
private RetrieveQueryResult retrieveSingleQuery(String queryText,
EmbeddingModel embeddingModel,
EmbeddingStore embeddingStore,
Map<String, String> filterCondition,
Map<String, Object> filterCondition,
int num) {
Embedding embeddedText = embeddingModel.embed(queryText).content();
Filter filter = createCombinedFilter(filterCondition);
@@ -159,14 +159,27 @@ public class EmbeddingServiceImpl implements EmbeddingService {
return retrieval;
}
private static Filter createCombinedFilter(Map<String, String> map) {
public static Filter createCombinedFilter(Map<String, Object> map) {
if (MapUtils.isEmpty(map)) {
return null;
}
Filter result = null;
for (Map.Entry<String, String> entry : map.entrySet()) {
IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue());
result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo);
for (Map.Entry<String, Object> entry : map.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
Filter orFilter = null;
if (value instanceof List) {
for (String val : (List<String>) value) {
IsEqualTo isEqualTo = new IsEqualTo(key, val);
orFilter = (orFilter == null) ? isEqualTo : Filter.or(orFilter, isEqualTo);
}
} else if (value instanceof String) {
orFilter = new IsEqualTo(key, value);
}
if (orFilter != null) {
result = (result == null) ? orFilter : Filter.and(result, orFilter);
}
}
return result;
}

View File

@@ -12,7 +12,7 @@ public class RetrieveQuery {
private List<String> queryTextsList;
private Map<String, String> filterCondition;
private Map<String, Object> filterCondition;
private List<List<Double>> queryEmbeddings;