(improvement)(headless) Upgrade to the latest version of langchain4j and add support for embedding deletion operation and reset. (#1660)

This commit is contained in:
lexluo09
2024-09-12 18:16:16 +08:00
committed by GitHub
parent 693356e46a
commit 4b1dab8e4a
16 changed files with 13307 additions and 16497 deletions

View File

@@ -18,4 +18,6 @@ public interface EmbeddingService {
List<RetrieveQueryResult> retrieveQuery(
String collectionName, RetrieveQuery retrieveQuery, int num);
void removeAll();
}

View File

@@ -7,11 +7,11 @@ import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
@@ -20,7 +20,6 @@ import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
@@ -46,9 +45,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override
public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
EmbeddingStore embeddingStore =
EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
for (TextSegment query : queries) {
String question = query.text();
try {
@@ -101,30 +99,23 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override
public void deleteQuery(String collectionName, List<TextSegment> queries) {
// Not supported yet in Milvus and Chroma
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
EmbeddingStore embeddingStore =
EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
try {
if (embeddingStore instanceof InMemoryEmbeddingStore) {
InMemoryEmbeddingStore inMemoryEmbeddingStore =
(InMemoryEmbeddingStore) embeddingStore;
List<String> queryIds =
queries.stream()
.map(textSegment -> TextSegmentConvert.getQueryId(textSegment))
.filter(Objects::nonNull)
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(queryIds)) {
MetadataFilterBuilder filterBuilder =
new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID);
Filter filter = filterBuilder.isIn(queryIds);
inMemoryEmbeddingStore.removeAll(filter);
for (String queryId : queryIds) {
cache.put(queryId, false);
}
}
} else {
throw new RuntimeException("Not supported yet.");
List<String> queryIds =
queries.stream()
.map(textSegment -> TextSegmentConvert.getQueryId(textSegment))
.filter(Objects::nonNull)
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(queryIds)) {
MetadataFilterBuilder filterBuilder =
new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID);
Filter filter = filterBuilder.isIn(queryIds);
embeddingStore.removeAll(filter);
queryIds.stream().forEach(queryId -> cache.put(queryId, false));
}
} catch (Exception e) {
log.error("deleteQuery error,collectionName:{},queries:{}", collectionName, queries);
}
@@ -149,6 +140,18 @@ public class EmbeddingServiceImpl implements EmbeddingService {
.collect(Collectors.toList());
}
@Override
public void removeAll() {
BaseEmbeddingStoreFactory factory =
(BaseEmbeddingStoreFactory) EmbeddingStoreFactoryProvider.getFactory();
Map<String, EmbeddingStore<TextSegment>> collectionNameToStore =
factory.getCollectionNameToStore();
for (EmbeddingStore<TextSegment> embeddingStore : collectionNameToStore.values()) {
embeddingStore.removeAll();
}
cache.invalidateAll();
}
private RetrieveQueryResult retrieveSingleQuery(
String queryText,
EmbeddingModel embeddingModel,
@@ -193,28 +196,36 @@ public class EmbeddingServiceImpl implements EmbeddingService {
return retrieval;
}
public static Filter createCombinedFilter(Map<String, Object> map) {
if (MapUtils.isEmpty(map)) {
public static Filter createCombinedFilter(Map<String, Object> criteriaMap) {
if (MapUtils.isEmpty(criteriaMap)) {
return null;
}
Filter result = null;
for (Map.Entry<String, Object> entry : map.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
Filter orFilter = null;
if (value instanceof List) {
for (String val : (List<String>) value) {
IsEqualTo isEqualTo = new IsEqualTo(key, val);
orFilter = (orFilter == null) ? isEqualTo : Filter.or(orFilter, isEqualTo);
Filter combinedFilter = null;
for (Map.Entry<String, Object> entry : criteriaMap.entrySet()) {
String fieldName = entry.getKey();
Object fieldValue = entry.getValue();
Filter fieldFilter = null;
if (fieldValue instanceof List) {
// Create an OR filter for each value in the list
for (String value : (List<String>) fieldValue) {
IsEqualTo equalToFilter = new IsEqualTo(fieldName, value);
fieldFilter =
(fieldFilter == null)
? equalToFilter
: Filter.or(fieldFilter, equalToFilter);
}
} else if (value instanceof String) {
orFilter = new IsEqualTo(key, value);
} else if (fieldValue instanceof String) {
// Create a simple equality filter
fieldFilter = new IsEqualTo(fieldName, fieldValue);
}
if (orFilter != null) {
result = (result == null) ? orFilter : Filter.and(result, orFilter);
// Combine the current field filter with the overall filter using AND logic
if (fieldFilter != null) {
combinedFilter =
(combinedFilter == null)
? fieldFilter
: Filter.and(combinedFilter, fieldFilter);
}
}
return result;
return combinedFilter;
}
}

View File

@@ -54,6 +54,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
Metadata.from(
JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class));
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
TextSegmentConvert.addQueryId(segment, exemplar.getQuestion());
embeddingService.deleteQuery(collection, Lists.newArrayList(segment));
}