mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-18 08:17:18 +00:00
(improvement)(headless) Upgrade to the latest version of langchain4j and add support for embedding deletion operation and reset. (#1660)
This commit is contained in:
@@ -18,4 +18,6 @@ public interface EmbeddingService {
|
||||
|
||||
List<RetrieveQueryResult> retrieveQuery(
|
||||
String collectionName, RetrieveQuery retrieveQuery, int num);
|
||||
|
||||
void removeAll();
|
||||
}
|
||||
|
||||
@@ -7,11 +7,11 @@ import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
@@ -20,7 +20,6 @@ import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
@@ -46,9 +45,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
|
||||
EmbeddingStore embeddingStore =
|
||||
EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
|
||||
for (TextSegment query : queries) {
|
||||
String question = query.text();
|
||||
try {
|
||||
@@ -101,30 +99,23 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
@Override
|
||||
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
||||
// Not supported yet in Milvus and Chroma
|
||||
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
EmbeddingStore embeddingStore =
|
||||
EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
|
||||
try {
|
||||
if (embeddingStore instanceof InMemoryEmbeddingStore) {
|
||||
InMemoryEmbeddingStore inMemoryEmbeddingStore =
|
||||
(InMemoryEmbeddingStore) embeddingStore;
|
||||
List<String> queryIds =
|
||||
queries.stream()
|
||||
.map(textSegment -> TextSegmentConvert.getQueryId(textSegment))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isNotEmpty(queryIds)) {
|
||||
MetadataFilterBuilder filterBuilder =
|
||||
new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID);
|
||||
Filter filter = filterBuilder.isIn(queryIds);
|
||||
inMemoryEmbeddingStore.removeAll(filter);
|
||||
for (String queryId : queryIds) {
|
||||
cache.put(queryId, false);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw new RuntimeException("Not supported yet.");
|
||||
|
||||
List<String> queryIds =
|
||||
queries.stream()
|
||||
.map(textSegment -> TextSegmentConvert.getQueryId(textSegment))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isNotEmpty(queryIds)) {
|
||||
MetadataFilterBuilder filterBuilder =
|
||||
new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID);
|
||||
Filter filter = filterBuilder.isIn(queryIds);
|
||||
embeddingStore.removeAll(filter);
|
||||
queryIds.stream().forEach(queryId -> cache.put(queryId, false));
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("deleteQuery error,collectionName:{},queries:{}", collectionName, queries);
|
||||
}
|
||||
@@ -149,6 +140,18 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeAll() {
|
||||
BaseEmbeddingStoreFactory factory =
|
||||
(BaseEmbeddingStoreFactory) EmbeddingStoreFactoryProvider.getFactory();
|
||||
Map<String, EmbeddingStore<TextSegment>> collectionNameToStore =
|
||||
factory.getCollectionNameToStore();
|
||||
for (EmbeddingStore<TextSegment> embeddingStore : collectionNameToStore.values()) {
|
||||
embeddingStore.removeAll();
|
||||
}
|
||||
cache.invalidateAll();
|
||||
}
|
||||
|
||||
private RetrieveQueryResult retrieveSingleQuery(
|
||||
String queryText,
|
||||
EmbeddingModel embeddingModel,
|
||||
@@ -193,28 +196,36 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
return retrieval;
|
||||
}
|
||||
|
||||
public static Filter createCombinedFilter(Map<String, Object> map) {
|
||||
if (MapUtils.isEmpty(map)) {
|
||||
public static Filter createCombinedFilter(Map<String, Object> criteriaMap) {
|
||||
if (MapUtils.isEmpty(criteriaMap)) {
|
||||
return null;
|
||||
}
|
||||
Filter result = null;
|
||||
for (Map.Entry<String, Object> entry : map.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
Object value = entry.getValue();
|
||||
Filter orFilter = null;
|
||||
|
||||
if (value instanceof List) {
|
||||
for (String val : (List<String>) value) {
|
||||
IsEqualTo isEqualTo = new IsEqualTo(key, val);
|
||||
orFilter = (orFilter == null) ? isEqualTo : Filter.or(orFilter, isEqualTo);
|
||||
Filter combinedFilter = null;
|
||||
for (Map.Entry<String, Object> entry : criteriaMap.entrySet()) {
|
||||
String fieldName = entry.getKey();
|
||||
Object fieldValue = entry.getValue();
|
||||
Filter fieldFilter = null;
|
||||
if (fieldValue instanceof List) {
|
||||
// Create an OR filter for each value in the list
|
||||
for (String value : (List<String>) fieldValue) {
|
||||
IsEqualTo equalToFilter = new IsEqualTo(fieldName, value);
|
||||
fieldFilter =
|
||||
(fieldFilter == null)
|
||||
? equalToFilter
|
||||
: Filter.or(fieldFilter, equalToFilter);
|
||||
}
|
||||
} else if (value instanceof String) {
|
||||
orFilter = new IsEqualTo(key, value);
|
||||
} else if (fieldValue instanceof String) {
|
||||
// Create a simple equality filter
|
||||
fieldFilter = new IsEqualTo(fieldName, fieldValue);
|
||||
}
|
||||
if (orFilter != null) {
|
||||
result = (result == null) ? orFilter : Filter.and(result, orFilter);
|
||||
// Combine the current field filter with the overall filter using AND logic
|
||||
if (fieldFilter != null) {
|
||||
combinedFilter =
|
||||
(combinedFilter == null)
|
||||
? fieldFilter
|
||||
: Filter.and(combinedFilter, fieldFilter);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
return combinedFilter;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
Metadata.from(
|
||||
JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class));
|
||||
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||
TextSegmentConvert.addQueryId(segment, exemplar.getQuestion());
|
||||
|
||||
embeddingService.deleteQuery(collection, Lists.newArrayList(segment));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user