diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java index 70b2fe71e..93f709182 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java @@ -1,15 +1,29 @@ package com.tencent.supersonic.common.util.embedding; +import static dev.langchain4j.internal.Utils.randomUUID; +import static java.util.Comparator.comparingDouble; + import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.spi.ServiceHelper; +import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory; +import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; +import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec; +import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.PriorityQueue; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; import lombok.extern.slf4j.Slf4j; /*** @@ -44,7 +58,6 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { addCollection(collectionName); embeddingStore = collectionNameToStore.get(collectionName); } - } return embeddingStore; } @@ -83,4 +96,148 @@ public class 
InMemoryS2EmbeddingStore implements S2EmbeddingStore { return results; } + + /** + * An {@link EmbeddingStore} that stores embeddings in memory. + *

+     * Uses a brute force approach by iterating over all embeddings to find the best matches.
+     * <p>
+     * Copied from {@code dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore}; the
+     * entries are kept in a {@link CopyOnWriteArrayList} to fix the
+     * {@link java.util.ConcurrentModificationException} seen in a multi-threaded environment.
+     *
+     * @param <Embedded> The class of the object that has been embedded.
+     *                   Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
+     */
+    public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
+
+        /**
+         * One stored item: its id, its embedding vector and the embedded object
+         * (which is {@code null} when the entry was added without one).
+         */
+        private static class Entry<Embedded> {
+
+            String id;
+            Embedding embedding;
+            Embedded embedded;
+
+            Entry(String id, Embedding embedding, Embedded embedded) {
+                this.id = id;
+                this.embedding = embedding;
+                this.embedded = embedded;
+            }
+
+            @Override
+            public boolean equals(Object o) {
+                if (this == o) {
+                    return true;
+                }
+                if (o == null || getClass() != o.getClass()) {
+                    return false;
+                }
+                Entry<?> that = (Entry<?>) o;
+                return Objects.equals(this.id, that.id)
+                        && Objects.equals(this.embedding, that.embedding)
+                        && Objects.equals(this.embedded, that.embedded);
+            }
+
+            @Override
+            public int hashCode() {
+                return Objects.hash(id, embedding, embedded);
+            }
+        }
+
+        // Iterating a CopyOnWriteArrayList never throws ConcurrentModificationException,
+        // so findRelevant can scan the entries while other threads are adding to them.
+        private final List<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
+
+        /**
+         * Stores an embedding under a freshly generated id.
+         *
+         * @param embedding the embedding to store
+         * @return the generated id
+         */
+        @Override
+        public String add(Embedding embedding) {
+            String id = randomUUID();
+            add(id, embedding);
+            return id;
+        }
+
+        /**
+         * Stores an embedding under the given id, with no embedded object attached.
+         *
+         * @param id        the id to store the embedding under
+         * @param embedding the embedding to store
+         */
+        @Override
+        public void add(String id, Embedding embedding) {
+            add(id, embedding, null);
+        }
+
+        /**
+         * Stores an embedding and its embedded object under a freshly generated id.
+         *
+         * @param embedding the embedding to store
+         * @param embedded  the object the embedding was computed from
+         * @return the generated id
+         */
+        @Override
+        public String add(Embedding embedding, Embedded embedded) {
+            String id = randomUUID();
+            add(id, embedding, embedded);
+            return id;
+        }
+
+        /**
+         * Appends a new entry; every other {@code add} overload ends up here.
+         * Existing entries with the same id are not replaced.
+         *
+         * @param id        the id to store the entry under
+         * @param embedding the embedding to store
+         * @param embedded  the object the embedding was computed from, may be {@code null}
+         */
+        public void add(String id, Embedding embedding, Embedded embedded) {
+            entries.add(new Entry<>(id, embedding, embedded));
+        }
+
+        /**
+         * Stores several embeddings, each under its own generated id.
+         *
+         * @param embeddings the embeddings to store
+         * @return the generated ids, in the same order as {@code embeddings}
+         */
+        @Override
+        public List<String> addAll(List<Embedding> embeddings) {
+            List<String> ids = new ArrayList<>();
+            for (Embedding embedding : embeddings) {
+                ids.add(add(embedding));
+            }
+            return ids;
+        }
+
+        /**
+         * Stores several embeddings paired index-by-index with their embedded objects.
+         *
+         * @param embeddings the embeddings to store
+         * @param embedded   the embedded objects; element {@code i} belongs to embedding {@code i}
+         * @return the generated ids, in the same order as {@code embeddings}
+         * @throws IllegalArgumentException if the two lists differ in size
+         */
+        @Override
+        public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
+            if (embeddings.size() != embedded.size()) {
+                throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
+            }
+
+            List<String> ids = new ArrayList<>();
+            for (int i = 0; i < embeddings.size(); i++) {
+                ids.add(add(embeddings.get(i), embedded.get(i)));
+            }
+            return ids;
+        }
+
+        /**
+         * Scans every stored entry and returns the best matches for the reference embedding.
+         *
+         * @param referenceEmbedding the embedding to compare the stored entries against
+         * @param maxResults         the maximum number of matches to return
+         * @param minScore           the lowest relevance score a match may have
+         * @return at most {@code maxResults} matches scoring at least {@code minScore},
+         *         ordered from highest to lowest score
+         */
+        @Override
+        public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding, int maxResults,
+                double minScore) {
+
+            // Ascending-score queue: its head is always the weakest match kept so far.
+            Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
+            PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
+
+            for (Entry<Embedded> entry : entries) {
+                double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
+                double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
+                if (score >= minScore) {
+                    matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
+                    // Over capacity: drop the weakest match so only the top maxResults remain.
+                    if (matches.size() > maxResults) {
+                        matches.poll();
+                    }
+                }
+            }
+
+            // The queue's iteration order is unspecified, so sort explicitly, best match first.
+            List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
+            result.sort(comparator);
+            Collections.reverse(result);
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            InMemoryEmbeddingStore<?> that = (InMemoryEmbeddingStore<?>) o;
+            return Objects.equals(this.entries, that.entries);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(entries);
+        }
+
+        /**
+         * Returns the codec created by the first {@link InMemoryEmbeddingStoreJsonCodecFactory}
+         * that {@link ServiceHelper} finds, or a {@link GsonInMemoryEmbeddingStoreJsonCodec}
+         * when none is found.
+         * NOTE(review): nothing in this class calls this method - confirm whether it is still needed.
+         */
+        private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
+            Collection<InMemoryEmbeddingStoreJsonCodecFactory> factories = ServiceHelper.loadFactories(
+                    InMemoryEmbeddingStoreJsonCodecFactory.class);
+            for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) {
+                return factory.create();
+            }
+            // fallback to default
+            return new GsonInMemoryEmbeddingStoreJsonCodec();
+        }
+
+    }
+}