[improvement](chat) fix concurrentModificationException in a multi-threaded environment. (#501)

This commit is contained in:
lexluo09
2023-12-13 13:26:01 +08:00
committed by GitHub
parent 340cb2c835
commit 2ad0553f6c

View File

@@ -1,15 +1,29 @@
package com.tencent.supersonic.common.util.embedding;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.Comparator.comparingDouble;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.spi.ServiceHelper;
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import lombok.extern.slf4j.Slf4j;
/***
@@ -44,7 +58,6 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
addCollection(collectionName);
embeddingStore = collectionNameToStore.get(collectionName);
}
}
return embeddingStore;
}
@@ -83,4 +96,148 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
return results;
}
/**
* An {@link EmbeddingStore} that stores embeddings in memory.
* <p>
* Uses a brute force approach by iterating over all embeddings to find the best matches.
* @param <Embedded> The class of the object that has been embedded.
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
* copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore
* and fix concurrentModificationException in a multi-threaded environment
*/
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
private static class Entry<Embedded> {
String id;
Embedding embedding;
Embedded embedded;
Entry(String id, Embedding embedding, Embedded embedded) {
this.id = id;
this.embedding = embedding;
this.embedded = embedded;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Entry<?> that = (Entry<?>) o;
return Objects.equals(this.id, that.id)
&& Objects.equals(this.embedding, that.embedding)
&& Objects.equals(this.embedded, that.embedded);
}
@Override
public int hashCode() {
return Objects.hash(id, embedding, embedded);
}
}
private final List<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
@Override
public String add(Embedding embedding) {
String id = randomUUID();
add(id, embedding);
return id;
}
@Override
public void add(String id, Embedding embedding) {
add(id, embedding, null);
}
@Override
public String add(Embedding embedding, Embedded embedded) {
String id = randomUUID();
add(id, embedding, embedded);
return id;
}
public void add(String id, Embedding embedding, Embedded embedded) {
entries.add(new Entry<>(id, embedding, embedded));
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<String> ids = new ArrayList<>();
for (Embedding embedding : embeddings) {
ids.add(add(embedding));
}
return ids;
}
@Override
public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
if (embeddings.size() != embedded.size()) {
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
}
List<String> ids = new ArrayList<>();
for (int i = 0; i < embeddings.size(); i++) {
ids.add(add(embeddings.get(i), embedded.get(i)));
}
return ids;
}
@Override
public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding, int maxResults,
double minScore) {
Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
for (Entry<Embedded> entry : entries) {
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= minScore) {
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
if (matches.size() > maxResults) {
matches.poll();
}
}
}
List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
result.sort(comparator);
Collections.reverse(result);
return result;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
InMemoryEmbeddingStore<?> that = (InMemoryEmbeddingStore<?>) o;
return Objects.equals(this.entries, that.entries);
}
@Override
public int hashCode() {
return Objects.hash(entries);
}
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
Collection<InMemoryEmbeddingStoreJsonCodecFactory> factories = ServiceHelper.loadFactories(
InMemoryEmbeddingStoreJsonCodecFactory.class);
for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) {
return factory.create();
}
// fallback to default
return new GsonInMemoryEmbeddingStoreJsonCodec();
}
}
}