(improvement)(chat) Fix the error in Milvus query and add the option to create EmbeddingStore based on caching mode (#1310)

This commit is contained in:
lexluo09
2024-07-01 16:29:43 +08:00
committed by GitHub
parent 37d08007c4
commit 7773442fbf
11 changed files with 489 additions and 104 deletions

View File

@@ -3,8 +3,8 @@ package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
@@ -15,15 +15,12 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
@Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
private Properties properties;
@@ -32,18 +29,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
}
@Override
public synchronized EmbeddingStore create(String collectionName) {
InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
if (Objects.nonNull(embeddingStore)) {
return embeddingStore;
}
embeddingStore = reloadFromPersistFile(collectionName);
public synchronized EmbeddingStore createEmbeddingStore(String collectionName) {
InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName);
if (Objects.isNull(embeddingStore)) {
embeddingStore = new InMemoryEmbeddingStore();
}
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
return embeddingStore;
}
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
@@ -67,10 +58,10 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
}
public synchronized void persistFile() {
if (MapUtils.isEmpty(collectionNameToStore)) {
if (MapUtils.isEmpty(super.collectionNameToStore)) {
return;
}
for (Map.Entry<String, InMemoryEmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
Path filePath = getPersistPath(entry.getKey());
if (Objects.isNull(filePath)) {
continue;
@@ -80,7 +71,11 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
if (!Files.exists(directoryPath)) {
Files.createDirectories(directoryPath);
}
entry.getValue().serializeToFile(filePath);
if (entry.getValue() instanceof InMemoryEmbeddingStore) {
InMemoryEmbeddingStore<TextSegment> inMemoryEmbeddingStore =
(InMemoryEmbeddingStore) entry.getValue();
inMemoryEmbeddingStore.serializeToFile(filePath);
}
} catch (Exception e) {
log.error("persistFile error, persistFile:" + filePath, e);
}