diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java index cf92fa63e..ec2f4ccf5 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java @@ -237,14 +237,22 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { } public void add(String id, Embedding embedding, Embedded embedded) { - entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded)); + if (checkEmbeddingNotExists(embedding)) { + entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded)); + } + } + + private boolean checkEmbeddingNotExists(Embedding embedding) { + return entries.stream().noneMatch(entry -> entry.embedding.equals(embedding)); } @Override public List addAll(List embeddings) { List ids = new ArrayList<>(); for (Embedding embedding : embeddings) { - ids.add(add(embedding)); + if (checkEmbeddingNotExists(embedding)) { + ids.add(add(embedding)); + } } return ids; } @@ -257,7 +265,9 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { List ids = new ArrayList<>(); for (int i = 0; i < embeddings.size(); i++) { - ids.add(add(embeddings.get(i), embedded.get(i))); + if (checkEmbeddingNotExists(embeddings.get(i))) { + ids.add(add(embeddings.get(i), embedded.get(i))); + } } return ids; }