diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java
index 70b2fe71e..93f709182 100644
--- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java
+++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java
@@ -1,15 +1,29 @@
package com.tencent.supersonic.common.util.embedding;
+import static dev.langchain4j.internal.Utils.randomUUID;
+import static java.util.Comparator.comparingDouble;
+
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.spi.ServiceHelper;
+import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
+import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
-import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.RelevanceScore;
+import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec;
+import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CopyOnWriteArrayList;
import lombok.extern.slf4j.Slf4j;
/***
@@ -44,7 +58,6 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
addCollection(collectionName);
embeddingStore = collectionNameToStore.get(collectionName);
}
-
}
return embeddingStore;
}
@@ -83,4 +96,148 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
return results;
}
+
+    /**
+     * An {@link EmbeddingStore} that stores embeddings in memory.
+     * <p>
+     * Uses a brute force approach by iterating over all embeddings to find the best matches.
+     * <p>
+     * Copied from {@code dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore}; the
+     * entries are kept in a {@link CopyOnWriteArrayList} to fix the
+     * {@code ConcurrentModificationException} seen in a multi-threaded environment.
+     *
+     * @param <Embedded> The class of the object that has been embedded.
+     *                   Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
+     */
+    public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
+
+        /**
+         * One stored item: its id, its embedding vector and the embedded object
+         * (which is {@code null} when only an embedding was added).
+         */
+        private static class Entry<Embedded> {
+
+            final String id;
+            final Embedding embedding;
+            final Embedded embedded;
+
+            Entry(String id, Embedding embedding, Embedded embedded) {
+                this.id = id;
+                this.embedding = embedding;
+                this.embedded = embedded;
+            }
+
+            @Override
+            public boolean equals(Object o) {
+                if (this == o) {
+                    return true;
+                }
+                if (o == null || getClass() != o.getClass()) {
+                    return false;
+                }
+                Entry<?> that = (Entry<?>) o;
+                return Objects.equals(this.id, that.id)
+                        && Objects.equals(this.embedding, that.embedding)
+                        && Objects.equals(this.embedded, that.embedded);
+            }
+
+            @Override
+            public int hashCode() {
+                return Objects.hash(id, embedding, embedded);
+            }
+        }
+
+        // Iteration over a CopyOnWriteArrayList works on a snapshot of the backing array, so
+        // findRelevant can run while other threads add entries.
+        private final List<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
+
+        /**
+         * Stores an embedding under a freshly generated id.
+         *
+         * @param embedding the embedding to store
+         * @return the generated id
+         */
+        @Override
+        public String add(Embedding embedding) {
+            String id = randomUUID();
+            add(id, embedding);
+            return id;
+        }
+
+        /**
+         * Stores an embedding under the given id, with no embedded object.
+         *
+         * @param id        the id to store the embedding under
+         * @param embedding the embedding to store
+         */
+        @Override
+        public void add(String id, Embedding embedding) {
+            add(id, embedding, null);
+        }
+
+        /**
+         * Stores an embedding together with the object it was computed from, under a freshly
+         * generated id.
+         *
+         * @param embedding the embedding to store
+         * @param embedded  the object that was embedded
+         * @return the generated id
+         */
+        @Override
+        public String add(Embedding embedding, Embedded embedded) {
+            String id = randomUUID();
+            add(id, embedding, embedded);
+            return id;
+        }
+
+        /**
+         * Appends a new entry built from the given id, embedding and embedded object.
+         *
+         * @param id        the id to store the entry under
+         * @param embedding the embedding to store
+         * @param embedded  the object that was embedded; may be {@code null}
+         */
+        public void add(String id, Embedding embedding, Embedded embedded) {
+            entries.add(new Entry<>(id, embedding, embedded));
+        }
+
+        /**
+         * Stores several embeddings, each under a freshly generated id.
+         *
+         * @param embeddings the embeddings to store
+         * @return the generated ids, in the same order as {@code embeddings}
+         */
+        @Override
+        public List<String> addAll(List<Embedding> embeddings) {
+            List<String> ids = new ArrayList<>(embeddings.size());
+            List<Entry<Embedded>> batch = new ArrayList<>(embeddings.size());
+            for (Embedding embedding : embeddings) {
+                String id = randomUUID();
+                ids.add(id);
+                batch.add(new Entry<Embedded>(id, embedding, null));
+            }
+            // A single bulk append copies the backing array once instead of once per element.
+            entries.addAll(batch);
+            return ids;
+        }
+
+        /**
+         * Stores several embeddings together with the objects they were computed from, each
+         * under a freshly generated id.
+         *
+         * @param embeddings the embeddings to store
+         * @param embedded   the embedded objects; element {@code i} belongs to embedding {@code i}
+         * @return the generated ids, in the same order as {@code embeddings}
+         * @throws IllegalArgumentException if the two lists differ in size
+         */
+        @Override
+        public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
+            if (embeddings.size() != embedded.size()) {
+                throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
+            }
+
+            List<String> ids = new ArrayList<>(embeddings.size());
+            List<Entry<Embedded>> batch = new ArrayList<>(embeddings.size());
+            for (int i = 0; i < embeddings.size(); i++) {
+                String id = randomUUID();
+                ids.add(id);
+                batch.add(new Entry<>(id, embeddings.get(i), embedded.get(i)));
+            }
+            // A single bulk append copies the backing array once instead of once per element.
+            entries.addAll(batch);
+            return ids;
+        }
+
+        /**
+         * Scans every stored entry and returns those scoring at least {@code minScore} against
+         * the reference embedding, keeping at most {@code maxResults} of the highest scores.
+         *
+         * @param referenceEmbedding the embedding to compare the stored entries with
+         * @param maxResults         the maximum number of matches to return
+         * @param minScore           the minimum score (inclusive) a match must reach
+         * @return the matches, ordered from highest to lowest score
+         */
+        @Override
+        public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding, int maxResults,
+                double minScore) {
+
+            Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
+            // Min-heap on score: the head is always the weakest match kept so far.
+            PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
+
+            for (Entry<Embedded> entry : entries) {
+                double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
+                double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
+                if (score >= minScore) {
+                    matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
+                    if (matches.size() > maxResults) {
+                        // Over capacity: drop the lowest-scoring match.
+                        matches.poll();
+                    }
+                }
+            }
+
+            // A PriorityQueue does not iterate in sorted order, so sort explicitly
+            // and then flip the list to get the best match first.
+            List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
+            result.sort(comparator);
+            Collections.reverse(result);
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            InMemoryEmbeddingStore<?> that = (InMemoryEmbeddingStore<?>) o;
+            return Objects.equals(this.entries, that.entries);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(entries);
+        }
+
+        /**
+         * Returns the codec created by the first factory that {@link ServiceHelper} finds, or a
+         * {@link GsonInMemoryEmbeddingStoreJsonCodec} when no factory is found.
+         * NOTE(review): nothing inside this class calls this method - confirm it is still needed.
+         */
+        private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
+            Collection<InMemoryEmbeddingStoreJsonCodecFactory> factories = ServiceHelper.loadFactories(
+                    InMemoryEmbeddingStoreJsonCodecFactory.class);
+            for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) {
+                return factory.create();
+            }
+            // fallback to default
+            return new GsonInMemoryEmbeddingStoreJsonCodec();
+        }
+
+    }
+
}