mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
[improvement](chat) support filterCondition in InMemoryS2EmbeddingStore (#523)
This commit is contained in:
@@ -20,11 +20,15 @@ import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/***
|
||||
* Implementation of S2EmbeddingStore within the Java process's in-memory.
|
||||
@@ -75,9 +79,11 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
for (String queryText : queryTextsList) {
|
||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, num);
|
||||
int maxResults = getMaxResults(num, filterCondition);
|
||||
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, maxResults);
|
||||
|
||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||
retrieveQueryResult.setQuery(queryText);
|
||||
@@ -87,9 +93,17 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
retrieval.setDistance(embeddingMatch.score());
|
||||
retrieval.setId(embeddingMatch.embeddingId());
|
||||
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
||||
retrieval.setMetadata(embeddingMatch.embedded().getMetadata());
|
||||
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();
|
||||
if (filterRetrieval(filterCondition, metadata)) {
|
||||
continue;
|
||||
}
|
||||
retrieval.setMetadata(metadata);
|
||||
retrievals.add(retrieval);
|
||||
}
|
||||
retrievals = retrievals.stream()
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
||||
.limit(num)
|
||||
.collect(Collectors.toList());
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
results.add(retrieveQueryResult);
|
||||
}
|
||||
@@ -97,14 +111,36 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
return results;
|
||||
}
|
||||
|
||||
private int getMaxResults(int num, Map<String, String> filterCondition) {
|
||||
int maxResults = num;
|
||||
if (MapUtils.isNotEmpty(filterCondition)) {
|
||||
maxResults = num * 5;
|
||||
}
|
||||
return maxResults;
|
||||
}
|
||||
|
||||
private boolean filterRetrieval(Map<String, String> filterCondition, Map<String, Object> metadata) {
|
||||
if (MapUtils.isNotEmpty(metadata) && MapUtils.isNotEmpty(filterCondition)) {
|
||||
for (Entry<String, Object> entry : metadata.entrySet()) {
|
||||
String filterValue = filterCondition.get(entry.getKey());
|
||||
if (StringUtils.isNotBlank(filterValue) && !filterValue.equalsIgnoreCase(
|
||||
entry.getValue().toString())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* An {@link EmbeddingStore} that stores embeddings in memory.
|
||||
* <p>
|
||||
* Uses a brute force approach by iterating over all embeddings to find the best matches.
|
||||
*
|
||||
* @param <Embedded> The class of the object that has been embedded.
|
||||
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
|
||||
* Copied from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore,
* with a fix for ConcurrentModificationException in multi-threaded environments.
|
||||
*/
|
||||
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user