diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 9fa1ed013..ac6a7333e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -16,7 +16,9 @@ import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.TextSegmentConvert; import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -95,6 +97,21 @@ public class EmbeddingServiceImpl implements EmbeddingService { @Override public void deleteQuery(String collectionName, List queries) { //Not supported yet in Milvus and Chroma + EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); + try { + if (embeddingStore instanceof InMemoryEmbeddingStore) { + InMemoryEmbeddingStore inMemoryEmbeddingStore = (InMemoryEmbeddingStore) embeddingStore; + List queryIds = queries.stream().map(textSegment -> TextSegmentConvert.getQueryId(textSegment)) + .filter(Objects::nonNull).collect(Collectors.toList()); + if (CollectionUtils.isNotEmpty(queryIds)) { + MetadataFilterBuilder filterBuilder = new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID); + Filter filter = filterBuilder.isIn(queryIds); + inMemoryEmbeddingStore.removeAll(filter); + } + } + } catch (Exception e) { + log.error("deleteQuery error,collectionName:{},queries:{}", collectionName, queries); + } } @Override