(improvement)(common) Fix the issue of duplicate inclusion of embedding data (#1359)

This commit is contained in:
lexluo09
2024-07-05 23:09:30 +08:00
committed by GitHub
parent 16c3ff0c30
commit a1ab7ac1c1
2 changed files with 47 additions and 5 deletions

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.common.service.impl;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
@@ -15,18 +17,19 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
@@ -38,13 +41,23 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
.maximumSize(10000)
.expireAfterWrite(10, TimeUnit.HOURS)
.build();
@Override
public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
for (TextSegment query : queries) {
String question = query.text();
try {
Embedding embedding = embeddingModel.embed(question).content();
boolean existSegment = existSegment(embeddingStore, query, embedding);
if (existSegment) {
continue;
}
embeddingStore.add(embedding, query);
} catch (Exception e) {
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
@@ -53,8 +66,35 @@ public class EmbeddingServiceImpl implements EmbeddingService {
}
}
private boolean existSegment(EmbeddingStore embeddingStore, TextSegment query, Embedding embedding) {
String queryId = TextSegmentConvert.getQueryId(query);
if (queryId == null) {
return false;
}
// Check cache first
Boolean cachedResult = cache.getIfPresent(queryId);
if (cachedResult != null) {
return cachedResult;
}
Map<String, String> filterCondition = new HashMap<>();
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embedding).filter(filter).maxResults(1).build();
EmbeddingSearchResult result = embeddingStore.search(request);
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
boolean exists = CollectionUtils.isNotEmpty(relevant);
cache.put(queryId, exists);
return exists;
}
@Override
public void deleteQuery(String collectionName, List<TextSegment> queries) {
//Not supported yet in Milvus and Chroma
}
@Override

View File

@@ -12,6 +12,7 @@ import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
@@ -45,6 +46,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
String.class, Object.class));
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
TextSegmentConvert.addQueryId(segment, exemplar.getQuestion());
embeddingService.addQuery(collection, Lists.newArrayList(segment));
}