mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(improvement)(common) Fix the issue of duplicate inclusion of embedding data (#1359)
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
package com.tencent.supersonic.common.service.impl;
|
package com.tencent.supersonic.common.service.impl;
|
||||||
|
|
||||||
|
import com.google.common.cache.Cache;
|
||||||
|
import com.google.common.cache.CacheBuilder;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
@@ -15,18 +17,19 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
|||||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import dev.langchain4j.store.embedding.filter.Filter;
|
import dev.langchain4j.store.embedding.filter.Filter;
|
||||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.MapUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.MapUtils;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -38,13 +41,23 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingModel embeddingModel;
|
private EmbeddingModel embeddingModel;
|
||||||
|
|
||||||
|
private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
|
||||||
|
.maximumSize(10000)
|
||||||
|
.expireAfterWrite(10, TimeUnit.HOURS)
|
||||||
|
.build();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||||
|
|
||||||
for (TextSegment query : queries) {
|
for (TextSegment query : queries) {
|
||||||
String question = query.text();
|
String question = query.text();
|
||||||
try {
|
try {
|
||||||
Embedding embedding = embeddingModel.embed(question).content();
|
Embedding embedding = embeddingModel.embed(question).content();
|
||||||
|
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
||||||
|
if (existSegment) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
embeddingStore.add(embedding, query);
|
embeddingStore.add(embedding, query);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
|
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
|
||||||
@@ -53,8 +66,35 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean existSegment(EmbeddingStore embeddingStore, TextSegment query, Embedding embedding) {
|
||||||
|
String queryId = TextSegmentConvert.getQueryId(query);
|
||||||
|
if (queryId == null) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Check cache first
|
||||||
|
Boolean cachedResult = cache.getIfPresent(queryId);
|
||||||
|
if (cachedResult != null) {
|
||||||
|
return cachedResult;
|
||||||
|
}
|
||||||
|
Map<String, String> filterCondition = new HashMap<>();
|
||||||
|
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
|
||||||
|
Filter filter = createCombinedFilter(filterCondition);
|
||||||
|
|
||||||
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
|
.queryEmbedding(embedding).filter(filter).maxResults(1).build();
|
||||||
|
|
||||||
|
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||||
|
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||||
|
|
||||||
|
boolean exists = CollectionUtils.isNotEmpty(relevant);
|
||||||
|
|
||||||
|
cache.put(queryId, exists);
|
||||||
|
return exists;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
||||||
|
//Not supported yet in Milvus and Chroma
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import dev.langchain4j.data.document.Metadata;
|
|||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.boot.CommandLineRunner;
|
import org.springframework.boot.CommandLineRunner;
|
||||||
@@ -45,6 +46,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
||||||
String.class, Object.class));
|
String.class, Object.class));
|
||||||
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||||
|
TextSegmentConvert.addQueryId(segment, exemplar.getQuestion());
|
||||||
|
|
||||||
embeddingService.addQuery(collection, Lists.newArrayList(segment));
|
embeddingService.addQuery(collection, Lists.newArrayList(segment));
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user