diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java
index 3e19db67d..9fa1ed013 100644
--- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java
+++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java
@@ -1,5 +1,7 @@
 package com.tencent.supersonic.common.service.impl;
 
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
 import com.tencent.supersonic.common.service.EmbeddingService;
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
@@ -15,18 +17,19 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
 import dev.langchain4j.store.embedding.TextSegmentConvert;
 import dev.langchain4j.store.embedding.filter.Filter;
 import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
-import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.collections.MapUtils;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.stereotype.Service;
-
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections.MapUtils;
+import org.apache.commons.collections4.CollectionUtils;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
 
 @Service
 @Slf4j
@@ -38,13 +41,27 @@ public class EmbeddingServiceImpl implements EmbeddingService {
     @Autowired
     private EmbeddingModel embeddingModel;
 
+    /**
+     * Keys ("collection:queryId") of segments already found in a store, kept so that
+     * repeated addQuery calls can skip the store lookup. Only positive results are kept.
+     */
+    private final Cache<String, Boolean> cache = CacheBuilder.newBuilder()
+            .maximumSize(10000)
+            .expireAfterWrite(10, TimeUnit.HOURS)
+            .build();
+
     @Override
     public void addQuery(String collectionName, List queries) {
         EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
+
         for (TextSegment query : queries) {
             String question = query.text();
             try {
                 Embedding embedding = embeddingModel.embed(question).content();
+                if (existSegment(collectionName, embeddingStore, query, embedding)) {
+                    // Same query id is already stored in this collection; skip the duplicate.
+                    continue;
+                }
                 embeddingStore.add(embedding, query);
             } catch (Exception e) {
                 log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
@@ -53,8 +70,48 @@ public class EmbeddingServiceImpl implements EmbeddingService {
         }
     }
 
+    /**
+     * Tells whether a segment with the same query id is already stored in the collection.
+     * A segment without a query id is never treated as existing.
+     *
+     * @param collectionName collection being written to; part of the cache key
+     * @param embeddingStore store searched when the cache has no entry
+     * @param query segment whose query id is looked up
+     * @param embedding embedding of the segment, used as the search vector
+     * @return true when a segment with this query id was found
+     */
+    private boolean existSegment(String collectionName, EmbeddingStore embeddingStore,
+            TextSegment query, Embedding embedding) {
+        String queryId = TextSegmentConvert.getQueryId(query);
+        if (queryId == null) {
+            return false;
+        }
+        // The cache is shared by all collections, so the key must include the collection name.
+        String cacheKey = collectionName + ":" + queryId;
+        if (cache.getIfPresent(cacheKey) != null) {
+            return true;
+        }
+        Map filterCondition = new HashMap<>();
+        filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
+        Filter filter = createCombinedFilter(filterCondition);
+
+        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
+                .queryEmbedding(embedding).filter(filter).maxResults(1).build();
+
+        EmbeddingSearchResult result = embeddingStore.search(request);
+        boolean exists = CollectionUtils.isNotEmpty(result.matches());
+
+        if (exists) {
+            // Cache hits only: a cached miss would go stale once addQuery stores the segment,
+            // and later calls would then skip the lookup and add duplicates.
+            cache.put(cacheKey, Boolean.TRUE);
+        }
+        return exists;
+    }
+
     @Override
     public void deleteQuery(String collectionName, List queries) {
+        // Not supported yet in Milvus and Chroma
     }
 
     @Override
diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java
index 1410c24d4..e8adf9b1b 100644
--- a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java
+++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java
@@ -12,6 +12,7 @@ import dev.langchain4j.data.document.Metadata;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.store.embedding.RetrieveQuery;
 import dev.langchain4j.store.embedding.RetrieveQueryResult;
+import dev.langchain4j.store.embedding.TextSegmentConvert;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.CommandLineRunner;
@@ -45,6 +46,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
         Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
                 String.class, Object.class));
         TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
+        TextSegmentConvert.addQueryId(segment, exemplar.getQuestion());
         embeddingService.addQuery(collection, Lists.newArrayList(segment));
     }
 