mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement](chat) Fix ConcurrentModificationException in a multi-threaded environment. (#501)
This commit is contained in:
@@ -1,15 +1,29 @@
|
|||||||
package com.tencent.supersonic.common.util.embedding;
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||||
|
import static java.util.Comparator.comparingDouble;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.spi.ServiceHelper;
|
||||||
|
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
|
||||||
|
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||||
|
import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec;
|
||||||
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.PriorityQueue;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
/***
|
/***
|
||||||
@@ -44,7 +58,6 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
addCollection(collectionName);
|
addCollection(collectionName);
|
||||||
embeddingStore = collectionNameToStore.get(collectionName);
|
embeddingStore = collectionNameToStore.get(collectionName);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return embeddingStore;
|
return embeddingStore;
|
||||||
}
|
}
|
||||||
@@ -83,4 +96,148 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
|
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An {@link EmbeddingStore} that stores embeddings in memory.
|
||||||
|
* <p>
|
||||||
|
* Uses a brute force approach by iterating over all embeddings to find the best matches.
|
||||||
|
* @param <Embedded> The class of the object that has been embedded.
|
||||||
|
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
|
||||||
|
* copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore
|
||||||
|
* and fix concurrentModificationException in a multi-threaded environment
|
||||||
|
*/
|
||||||
|
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
|
||||||
|
|
||||||
|
private static class Entry<Embedded> {
|
||||||
|
|
||||||
|
String id;
|
||||||
|
Embedding embedding;
|
||||||
|
Embedded embedded;
|
||||||
|
|
||||||
|
Entry(String id, Embedding embedding, Embedded embedded) {
|
||||||
|
this.id = id;
|
||||||
|
this.embedding = embedding;
|
||||||
|
this.embedded = embedded;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (o == null || getClass() != o.getClass()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Entry<?> that = (Entry<?>) o;
|
||||||
|
return Objects.equals(this.id, that.id)
|
||||||
|
&& Objects.equals(this.embedding, that.embedding)
|
||||||
|
&& Objects.equals(this.embedded, that.embedded);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(id, embedding, embedded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final List<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String add(Embedding embedding) {
|
||||||
|
String id = randomUUID();
|
||||||
|
add(id, embedding);
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void add(String id, Embedding embedding) {
|
||||||
|
add(id, embedding, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String add(Embedding embedding, Embedded embedded) {
|
||||||
|
String id = randomUUID();
|
||||||
|
add(id, embedding, embedded);
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void add(String id, Embedding embedding, Embedded embedded) {
|
||||||
|
entries.add(new Entry<>(id, embedding, embedded));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> addAll(List<Embedding> embeddings) {
|
||||||
|
List<String> ids = new ArrayList<>();
|
||||||
|
for (Embedding embedding : embeddings) {
|
||||||
|
ids.add(add(embedding));
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
|
||||||
|
if (embeddings.size() != embedded.size()) {
|
||||||
|
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> ids = new ArrayList<>();
|
||||||
|
for (int i = 0; i < embeddings.size(); i++) {
|
||||||
|
ids.add(add(embeddings.get(i), embedded.get(i)));
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding, int maxResults,
|
||||||
|
double minScore) {
|
||||||
|
|
||||||
|
Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
|
||||||
|
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
|
||||||
|
|
||||||
|
for (Entry<Embedded> entry : entries) {
|
||||||
|
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
|
||||||
|
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
||||||
|
if (score >= minScore) {
|
||||||
|
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
|
||||||
|
if (matches.size() > maxResults) {
|
||||||
|
matches.poll();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
|
||||||
|
result.sort(comparator);
|
||||||
|
Collections.reverse(result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (o == null || getClass() != o.getClass()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
InMemoryEmbeddingStore<?> that = (InMemoryEmbeddingStore<?>) o;
|
||||||
|
return Objects.equals(this.entries, that.entries);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(entries);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
|
||||||
|
Collection<InMemoryEmbeddingStoreJsonCodecFactory> factories = ServiceHelper.loadFactories(
|
||||||
|
InMemoryEmbeddingStoreJsonCodecFactory.class);
|
||||||
|
for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) {
|
||||||
|
return factory.create();
|
||||||
|
}
|
||||||
|
// fallback to default
|
||||||
|
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user