(improvement)(headless) Add deduplication and persistence to InMemoryEmbeddingStore (#1256)

This commit is contained in:
lexluo09
2024-06-27 22:24:49 +08:00
committed by GitHub
parent 9d921dc47f
commit 391c0dccc8
14 changed files with 361 additions and 26 deletions

View File

@@ -7,5 +7,5 @@ import lombok.Setter;
@Setter
class EmbeddingStoreProperties {
private String filePath;
private String persistPath;
}

View File

@@ -1,8 +1,6 @@
package dev.langchain4j.inmemory.spring;
import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -14,6 +12,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class InMemoryAutoConfig {
@@ -22,8 +22,8 @@ public class InMemoryAutoConfig {
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.file-path")
EmbeddingStoreFactory milvusChatModel(Properties properties) {
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
return new InMemoryEmbeddingStoreFactory(properties);
}

View File

@@ -1,18 +1,27 @@
package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
@Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
private Properties properties;
@@ -28,9 +37,63 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
if (Objects.nonNull(embeddingStore)) {
return embeddingStore;
}
embeddingStore = new InMemoryEmbeddingStore();
embeddingStore = reloadFromPersistFile(collectionName);
if (Objects.isNull(embeddingStore)) {
embeddingStore = new InMemoryEmbeddingStore();
}
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
return embeddingStore;
}
/**
 * Tries to rebuild the store of a collection from its persisted file.
 *
 * <p>A store is loaded only when {@link #getPersistPath(String)} yields a path, that file
 * exists, and the collection name differs from both the meta collection name and the
 * text2sql collection name reported by {@link EmbeddingConfig}. Any exception raised while
 * loading is logged and swallowed, so callers only ever see "loaded" or "not loaded".
 *
 * @param collectionName name of the collection whose persisted file should be read
 * @return the store read from the file, or {@code null} when nothing was loaded
 */
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
    Path filePath = getPersistPath(collectionName);
    if (Objects.isNull(filePath)) {
        // No persist file can be resolved for this collection.
        return null;
    }
    InMemoryEmbeddingStore<TextSegment> embeddingStore = null;
    try {
        EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
        // The meta and text2sql collections are never restored from disk here.
        if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName())
                && !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
            embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
            // Re-wrap the loaded entries in a CopyOnWriteArraySet.
            // NOTE(review): presumably so the restored store keeps set semantics and
            // tolerates concurrent access -- confirm against InMemoryEmbeddingStore.
            embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries);
            log.info("embeddingStore reload from file:{}", filePath);
        }
    } catch (Exception e) {
        // Best effort: a failed load falls back to "nothing loaded" instead of failing the caller.
        // The trailing Throwable is logged with its stack trace by SLF4J.
        log.error("load persistFile error, persistFile:{}", filePath, e);
    }
    return embeddingStore;
}
/**
 * Serializes every cached embedding store to its persist file.
 *
 * <p>Collections for which {@link #getPersistPath(String)} yields {@code null} are skipped.
 * The parent directory of each target file is created when it does not exist yet. A failure
 * while writing one collection is logged and does not stop the remaining collections from
 * being written.
 */
public synchronized void persistFile() {
    if (MapUtils.isEmpty(collectionNameToStore)) {
        return;
    }
    for (Map.Entry<String, InMemoryEmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
        Path filePath = getPersistPath(entry.getKey());
        if (Objects.isNull(filePath)) {
            continue;
        }
        try {
            Path directoryPath = filePath.getParent();
            // Make sure the target directory is present before the store writes its file.
            if (!Files.exists(directoryPath)) {
                Files.createDirectories(directoryPath);
            }
            entry.getValue().serializeToFile(filePath);
        } catch (Exception e) {
            // Keep going: one collection failing to persist must not block the others.
            // The trailing Throwable is logged with its stack trace by SLF4J.
            log.error("persistFile error, persistFile:{}", filePath, e);
        }
    }
}
/**
 * Resolves the file in which the given collection is persisted.
 *
 * <p>The file name is {@link #PERSISTENT_FILE_PRE} followed by the collection name, placed
 * under the persist path taken from the embedding-store properties.
 *
 * @param collectionName collection whose persist file is wanted
 * @return the resolved file path, or {@code null} when the configured persist path is
 *         {@code null} or empty
 */
private Path getPersistPath(String collectionName) {
    final String baseDir = properties.getEmbeddingStore().getPersistPath();
    return StringUtils.isEmpty(baseDir)
            ? null
            : Paths.get(baseDir, PERSISTENT_FILE_PRE + collectionName);
}
}