From 0cbfe473dd8a3af98575b489012eceb9b375e4a7 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 22 Dec 2023 14:49:53 +0800 Subject: [PATCH] (improvement)(chat) Change the storage of InMemoryEmbeddingStore entity to a Set for deduplication. (#564) --- .../embedding/InMemoryS2EmbeddingStore.java | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java index ec2f4ccf5..828816fb0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java @@ -25,8 +25,9 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; import java.util.PriorityQueue; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CopyOnWriteArraySet; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.MapUtils; @@ -49,7 +50,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { try { if (Files.exists(filePath)) { embeddingStore = InMemoryEmbeddingStore.fromFile(filePath); - embeddingStore.entries = new CopyOnWriteArrayList<>(embeddingStore.entries); + embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries); log.info("embeddingStore reload from file:{}", filePath); } } catch (Exception e) { @@ -215,7 +216,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { } private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec(); - private List> entries = new CopyOnWriteArrayList<>(); + private Set> entries = new CopyOnWriteArraySet<>(); @Override public String add(Embedding embedding) { @@ -237,22 +238,14 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { } public void add(String id, Embedding embedding, Embedded embedded) { - if (checkEmbeddingNotExists(embedding)) { - entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded)); - } - } - - private boolean checkEmbeddingNotExists(Embedding embedding) { - return entries.stream().noneMatch(entry -> entry.embedding.equals(embedding)); + entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded)); } @Override public List addAll(List embeddings) { List ids = new ArrayList<>(); for (Embedding embedding : embeddings) { - if (checkEmbeddingNotExists(embedding)) { - ids.add(add(embedding)); - } + ids.add(add(embedding)); } return ids; } @@ -265,9 +258,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { List ids = new ArrayList<>(); for (int i = 0; i < embeddings.size(); i++) { - if (checkEmbeddingNotExists(embeddings.get(i))) { - ids.add(add(embeddings.get(i), embedded.get(i))); - } + ids.add(add(embeddings.get(i), embedded.get(i))); } return ids; }