(improvement)(headless) Add deduplication and persistence to InMemoryEmbeddingStore (#1256)

This commit is contained in:
lexluo09
2024-06-27 22:24:49 +08:00
committed by GitHub
parent 9d921dc47f
commit 391c0dccc8
14 changed files with 361 additions and 26 deletions

View File

@@ -7,5 +7,5 @@ import lombok.Setter;
@Setter
class EmbeddingStoreProperties {
private String filePath;
private String persistPath;
}

View File

@@ -1,8 +1,6 @@
package dev.langchain4j.inmemory.spring;
import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -14,6 +12,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class InMemoryAutoConfig {
@@ -22,8 +22,8 @@ public class InMemoryAutoConfig {
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.file-path")
EmbeddingStoreFactory milvusChatModel(Properties properties) {
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
return new InMemoryEmbeddingStoreFactory(properties);
}

View File

@@ -1,18 +1,27 @@
package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
@Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
private Properties properties;
@@ -28,9 +37,63 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
if (Objects.nonNull(embeddingStore)) {
return embeddingStore;
}
embeddingStore = new InMemoryEmbeddingStore();
embeddingStore = reloadFromPersistFile(collectionName);
if (Objects.isNull(embeddingStore)) {
embeddingStore = new InMemoryEmbeddingStore();
}
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
return embeddingStore;
}
/**
 * Tries to rebuild the store of a collection from its persist file.
 * <p>
 * Nothing is loaded when no persist path is configured, when the file does not exist, or when the
 * collection is the meta collection or the text2sql collection named by {@link EmbeddingConfig}.
 * <p>
 * A store is returned only once it is fully prepared. Any failure while reading or post-processing
 * the file is logged and yields {@code null}, so a half-initialised store is never handed out.
 *
 * @param collectionName name of the collection whose persist file should be read
 * @return the reloaded store, or {@code null} when nothing was (or could be) loaded
 */
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
    Path filePath = getPersistPath(collectionName);
    if (Objects.isNull(filePath)) {
        return null;
    }
    try {
        EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
        // Same evaluation order as before: the existence check short-circuits the name checks.
        if (!Files.exists(filePath)
                || collectionName.equals(embeddingConfig.getMetaCollectionName())
                || collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
            return null;
        }
        InMemoryEmbeddingStore<TextSegment> loaded = InMemoryEmbeddingStore.fromFile(filePath);
        if (Objects.isNull(loaded) || Objects.isNull(loaded.entries)) {
            // An empty or malformed file can deserialize without entries; such a store would fail on first use.
            log.warn("persistFile has no usable entries, ignored, persistFile:{}", filePath);
            return null;
        }
        // Replace whatever Set implementation the codec produced with the concurrent one used by the store.
        loaded.entries = new CopyOnWriteArraySet<>(loaded.entries);
        log.info("embeddingStore reload from file:{}", filePath);
        return loaded;
    } catch (Exception e) {
        // Best effort: a broken persist file must not prevent the collection from being created.
        log.error("load persistFile error, persistFile:{}", filePath, e);
        return null;
    }
}
/**
 * Writes every cached store to its persist file.
 * <p>
 * Collections without a persist path are skipped. The parent directory is created when it is
 * missing. A failure for one collection is logged and does not stop the remaining ones.
 */
public synchronized void persistFile() {
    if (MapUtils.isEmpty(collectionNameToStore)) {
        return;
    }
    collectionNameToStore.forEach((collectionName, store) -> {
        Path target = getPersistPath(collectionName);
        if (Objects.isNull(target)) {
            // No persist path configured: nothing to write for this collection.
            return;
        }
        try {
            Path parentDir = target.getParent();
            if (!Files.exists(parentDir)) {
                Files.createDirectories(parentDir);
            }
            store.serializeToFile(target);
        } catch (Exception e) {
            log.error("persistFile error, persistFile:" + target, e);
        }
    });
}
/**
 * Resolves the persist file of a collection.
 * <p>
 * The file name is {@link #PERSISTENT_FILE_PRE} followed by the collection name, located inside
 * the configured persist path.
 *
 * @param collectionName name of the collection
 * @return the persist file path, or {@code null} when the configured persist path is null or empty
 */
private Path getPersistPath(String collectionName) {
    String baseDir = properties.getEmbeddingStore().getPersistPath();
    return StringUtils.isEmpty(baseDir)
            ? null
            : Paths.get(baseDir, PERSISTENT_FILE_PRE + collectionName);
}
}

View File

@@ -0,0 +1,250 @@
package dev.langchain4j.store.embedding.inmemory;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.stream.IntStream;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
import static java.util.Comparator.comparingDouble;
import static java.util.stream.Collectors.toList;
/**
 * An {@link EmbeddingStore} that keeps all embeddings in memory.
 * <p>
 * Search is brute force: every stored entry is compared with the query embedding.
 * <p>
 * Entries are held in a {@link Set}, so adding an entry that equals an existing one
 * (same id, embedding and embedded object) does not create a duplicate.
 * <p>
 * This store can be persisted using the {@link #serializeToJson()} and {@link #serializeToFile(Path)} methods.
 * <p>
 * It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link #fromFile(Path)} methods.
 * Files are written and read as UTF-8, independent of the platform default charset.
 *
 * @param <Embedded> The class of the object that has been embedded.
 *                   Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
 */
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {

    /**
     * Backing collection of stored entries; set semantics provide the deduplication.
     * Kept public and non-final because it is part of the existing surface of this class.
     */
    public Set<Entry<Embedded>> entries = new CopyOnWriteArraySet<>();

    /**
     * Stores an embedding under a newly generated id.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        String id = randomUUID();
        add(id, embedding);
        return id;
    }

    /**
     * Stores an embedding under the given id, without an embedded object.
     *
     * @param id        the id of the entry, must not be blank
     * @param embedding the embedding to store, must not be null
     */
    @Override
    public void add(String id, Embedding embedding) {
        add(id, embedding, null);
    }

    /**
     * Stores an embedding together with its embedded object under a newly generated id.
     *
     * @param embedding the embedding to store
     * @param embedded  the object the embedding was computed from, may be null
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, Embedded embedded) {
        String id = randomUUID();
        add(id, embedding, embedded);
        return id;
    }

    /**
     * Stores an embedding together with its embedded object under the given id.
     * An entry equal to an already stored one is not added again.
     *
     * @param id        the id of the entry, must not be blank
     * @param embedding the embedding to store, must not be null
     * @param embedded  the object the embedding was computed from, may be null
     */
    public void add(String id, Embedding embedding, Embedded embedded) {
        entries.add(new Entry<>(id, embedding, embedded));
    }

    /**
     * Adds the given entries and returns the ids of all of them, in input order
     * (including entries that were already present).
     */
    private List<String> add(List<Entry<Embedded>> newEntries) {
        entries.addAll(newEntries);
        return newEntries.stream()
                .map(entry -> entry.id)
                .collect(toList());
    }

    /**
     * Stores several embeddings, each under a newly generated id.
     *
     * @param embeddings the embeddings to store
     * @return the generated ids, in the order of {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<Entry<Embedded>> newEntries = embeddings.stream()
                .map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
                .collect(toList());
        return add(newEntries);
    }

    /**
     * Stores several embeddings with their embedded objects, each under a newly generated id.
     *
     * @param embeddings the embeddings to store
     * @param embedded   the embedded objects; element {@code i} belongs to embedding {@code i}
     * @return the generated ids, in the order of {@code embeddings}
     * @throws IllegalArgumentException if the two lists differ in size
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
        if (embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        List<Entry<Embedded>> newEntries = IntStream.range(0, embeddings.size())
                .mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
                .collect(toList());
        return add(newEntries);
    }

    /**
     * Removes every entry whose id is contained in {@code ids}.
     *
     * @param ids the ids to remove, must not be empty
     */
    @Override
    public void removeAll(Collection<String> ids) {
        ensureNotEmpty(ids, "ids");
        entries.removeIf(entry -> ids.contains(entry.id));
    }

    /**
     * Removes every entry whose {@link TextSegment} metadata matches the filter.
     * Entries without an embedded object are kept.
     *
     * @param filter the filter to test metadata against, must not be null
     * @throws UnsupportedOperationException if an entry holds an embedded object that is not a {@link TextSegment}
     */
    @Override
    public void removeAll(Filter filter) {
        ensureNotNull(filter, "filter");
        entries.removeIf(entry -> {
            if (entry.embedded instanceof TextSegment) {
                return filter.test(((TextSegment) entry.embedded).metadata());
            } else if (entry.embedded == null) {
                return false;
            } else {
                throw new UnsupportedOperationException("Not supported yet.");
            }
        });
    }

    /**
     * Removes all entries.
     */
    @Override
    public void removeAll() {
        entries.clear();
    }

    /**
     * Finds the entries most similar to the query embedding.
     * <p>
     * Each entry is scored by the relevance score derived from its cosine similarity to the query.
     * Entries scoring below {@code minScore} are dropped, at most {@code maxResults} are kept and
     * the result is ordered by descending score. The metadata filter is applied only to entries
     * whose embedded object is a {@link TextSegment}; other entries are not filtered out.
     *
     * @param embeddingSearchRequest query embedding, filter, minimum score and result limit
     * @return the matching entries, best match first
     */
    @Override
    public EmbeddingSearchResult<Embedded> search(EmbeddingSearchRequest embeddingSearchRequest) {

        Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
        // Min-heap on score: the weakest match sits at the head and is evicted first.
        PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);

        Filter filter = embeddingSearchRequest.filter();

        for (Entry<Embedded> entry : entries) {

            if (filter != null && entry.embedded instanceof TextSegment) {
                Metadata metadata = ((TextSegment) entry.embedded).metadata();
                if (!filter.test(metadata)) {
                    continue;
                }
            }

            double cosineSimilarity = CosineSimilarity.between(entry.embedding,
                    embeddingSearchRequest.queryEmbedding());
            double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
            if (score >= embeddingSearchRequest.minScore()) {
                matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
                // Keep the heap bounded to maxResults by dropping the current weakest match.
                if (matches.size() > embeddingSearchRequest.maxResults()) {
                    matches.poll();
                }
            }
        }

        // The heap's iteration order is unspecified, so sort explicitly, then flip to best-first.
        List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
        result.sort(comparator);
        Collections.reverse(result);

        return new EmbeddingSearchResult<>(result);
    }

    /**
     * Serializes this store to JSON using the codec returned by {@link #loadCodec()}.
     *
     * @return the JSON representation of this store
     */
    public String serializeToJson() {
        return loadCodec().toJson(this);
    }

    /**
     * Writes the JSON representation of this store to a file, creating it or replacing its content.
     * The file is encoded as UTF-8 so that it can be read back regardless of the platform default charset.
     *
     * @param filePath the file to write
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be written
     */
    public void serializeToFile(Path filePath) {
        try {
            String json = serializeToJson();
            byte[] content = json.getBytes(java.nio.charset.StandardCharsets.UTF_8);
            Files.write(filePath, content, CREATE, TRUNCATE_EXISTING);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Same as {@link #serializeToFile(Path)}, with the file given as a string.
     *
     * @param filePath the file to write
     */
    public void serializeToFile(String filePath) {
        serializeToFile(Paths.get(filePath));
    }

    /**
     * Recreates a store from its JSON representation using the codec returned by {@link #loadCodec()}.
     *
     * @param json JSON previously produced by {@link #serializeToJson()}
     * @return the deserialized store
     */
    public static InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
        return loadCodec().fromJson(json);
    }

    /**
     * Recreates a store from a file written by {@link #serializeToFile(Path)}.
     * The file is decoded as UTF-8.
     *
     * @param filePath the file to read
     * @return the deserialized store
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    public static InMemoryEmbeddingStore<TextSegment> fromFile(Path filePath) {
        try {
            String json = new String(Files.readAllBytes(filePath), java.nio.charset.StandardCharsets.UTF_8);
            return fromJson(json);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Same as {@link #fromFile(Path)}, with the file given as a string.
     *
     * @param filePath the file to read
     * @return the deserialized store
     */
    public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) {
        return fromFile(Paths.get(filePath));
    }

    /**
     * A stored embedding with its id and the optional object it was computed from.
     * Equality covers all three fields, which is what the backing set uses to detect duplicates.
     */
    private static class Entry<Embedded> {

        String id;
        Embedding embedding;
        Embedded embedded;

        Entry(String id, Embedding embedding) {
            this(id, embedding, null);
        }

        Entry(String id, Embedding embedding, Embedded embedded) {
            this.id = ensureNotBlank(id, "id");
            this.embedding = ensureNotNull(embedding, "embedding");
            this.embedded = embedded;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Entry<?> that = (Entry<?>) o;
            return Objects.equals(this.id, that.id)
                    && Objects.equals(this.embedding, that.embedding)
                    && Objects.equals(this.embedded, that.embedded);
        }

        @Override
        public int hashCode() {
            return Objects.hash(id, embedding, embedded);
        }
    }

    /**
     * Returns the codec of the first {@link InMemoryEmbeddingStoreJsonCodecFactory} found by
     * {@code loadFactories}, or the Gson-based default when none is registered.
     */
    private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
        for (InMemoryEmbeddingStoreJsonCodecFactory factory :
                loadFactories(InMemoryEmbeddingStoreJsonCodecFactory.class)) {
            return factory.create();
        }
        return new GsonInMemoryEmbeddingStoreJsonCodec();
    }
}