[improvement](chat) Add EmbeddingPersistentTask to Persist and reload InMemoryS2EmbeddingStore (#525)

This commit is contained in:
lexluo09
2023-12-17 21:53:21 +08:00
committed by GitHub
parent fe75b3e393
commit 5016881ce3
6 changed files with 81 additions and 12 deletions

View File

@@ -32,4 +32,6 @@ public class EmbeddingConfig {
@Value("${embedding.metric.analyzeQuery.nResult:5}") @Value("${embedding.metric.analyzeQuery.nResult:5}")
private int metricAnalyzeQueryResultNum; private int metricAnalyzeQueryResultNum;
/**
 * Directory under which the in-memory embedding store files are written and reloaded.
 *
 * <p>Resolved from {@code inMemoryEmbeddingStore.persistent.path} (the same key family as the
 * {@code inMemoryEmbeddingStore.persistent.cron} schedule), then from the earlier
 * {@code embeddingStore.persistent.path} key so existing overrides keep working, and finally
 * defaults to {@code /tmp}.
 */
@Value("${inMemoryEmbeddingStore.persistent.path:${embeddingStore.persistent.path:/tmp}}")
private String embeddingStorePersistentPath;
} }

View File

@@ -0,0 +1,32 @@
package com.tencent.supersonic.common.util.embedding;
import com.tencent.supersonic.common.util.ComponentFactory;
import javax.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
/**
 * Spring component that flushes the in-memory embedding store to disk, both on a cron
 * schedule and right before this bean is destroyed.
 *
 * <p>The flush only happens when the store obtained from {@link ComponentFactory} is an
 * {@link InMemoryS2EmbeddingStore}; for any other implementation both triggers do nothing.
 */
@Component
@Slf4j
public class EmbeddingPersistentTask {

    private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();

    /**
     * Scheduled trigger. The cron expression is read from
     * {@code inMemoryEmbeddingStore.persistent.cron} and defaults to the top of every hour.
     */
    @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
    public void executeTask() {
        embeddingStorePersistentToFile();
    }

    /** Shutdown trigger: performs one last flush before the bean goes away. */
    @PreDestroy
    public void onShutdown() {
        embeddingStorePersistentToFile();
    }

    /** Writes the store to its persistent files if, and only if, it is the in-memory variant. */
    private void embeddingStorePersistentToFile() {
        if (!(s2EmbeddingStore instanceof InMemoryS2EmbeddingStore)) {
            return;
        }
        InMemoryS2EmbeddingStore inMemoryStore = (InMemoryS2EmbeddingStore) s2EmbeddingStore;
        log.info("start persistentToFile");
        inMemoryStore.persistentToFile();
        log.info("end persistentToFile");
    }
}

View File

@@ -3,14 +3,13 @@ package com.tencent.supersonic.common.util.embedding;
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken; import com.google.gson.reflect.TypeToken;
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore; import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
import java.lang.reflect.Type; import java.lang.reflect.Type;
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec { public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {
@Override @Override
public InMemoryEmbeddingStore<TextSegment> fromJson(String json) { public InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
Type type = new TypeToken<InMemoryEmbeddingStore<TextSegment>>() { Type type = new TypeToken<InMemoryEmbeddingStore<EmbeddingQuery>>() {
}.getType(); }.getType();
return new Gson().fromJson(json, type); return new Gson().fromJson(json, type);
} }

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.common.util.embedding; package com.tencent.supersonic.common.util.embedding;
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore; import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
/**
 * Converts an {@link InMemoryEmbeddingStore} to and from its JSON representation.
 *
 * <p>Only one {@code fromJson(String)} may be declared: two declarations that differ solely in
 * return type do not compile, so the {@code EmbeddingQuery} variant is the one kept.
 */
public interface InMemoryEmbeddingStoreJsonCodec {

    /**
     * Rebuilds a store from JSON.
     *
     * @param json serialized store
     * @return the deserialized store holding {@code EmbeddingQuery} items
     */
    InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json);

    /**
     * Serializes a store to JSON.
     *
     * @param store store to serialize, with any embedded type
     * @return the JSON text
     */
    String toJson(InMemoryEmbeddingStore<?> store);
}

View File

@@ -5,9 +5,9 @@ import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
import static java.util.Comparator.comparingDouble; import static java.util.Comparator.comparingDouble;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
@@ -38,12 +38,44 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j @Slf4j
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore = private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
new ConcurrentHashMap<>(); new ConcurrentHashMap<>();
@Override @Override
public synchronized void addCollection(String collectionName) { public synchronized void addCollection(String collectionName) {
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore()); InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null;
Path filePath = getPersistentPath(collectionName);
try {
if (Files.exists(filePath)) {
embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
embeddingStore.entries = new CopyOnWriteArrayList<>(embeddingStore.entries);
log.info("embeddingStore reload from file:{}", filePath);
}
} catch (Exception e) {
log.error("load persistentFile error, persistentFile:" + filePath, e);
}
if (Objects.isNull(embeddingStore)) {
embeddingStore = new InMemoryEmbeddingStore();
}
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
}
/**
 * Builds the persistent file location for a collection: the directory configured in
 * {@link EmbeddingConfig} joined with {@code PERSISTENT_FILE_PRE + collectionName}.
 *
 * @param collectionName collection whose file location is wanted
 * @return path of that collection's persistent file
 */
private Path getPersistentPath(String collectionName) {
    String baseDir = ContextUtils.getBean(EmbeddingConfig.class).getEmbeddingStorePersistentPath();
    return Paths.get(baseDir, PERSISTENT_FILE_PRE + collectionName);
}
public void persistentToFile() {
for (Entry<String, InMemoryEmbeddingStore<EmbeddingQuery>> entry : collectionNameToStore.entrySet()) {
Path filePath = getPersistentPath(entry.getKey());
try {
entry.getValue().serializeToFile(filePath);
} catch (Exception e) {
log.error("persistentToFile error, persistentFile:" + filePath, e);
}
}
} }
@Override @Override
@@ -179,7 +211,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
} }
private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec(); private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();
private final List<InMemoryEmbeddingStore.Entry<Embedded>> entries = new CopyOnWriteArrayList<>(); private List<InMemoryEmbeddingStore.Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
@Override @Override
public String add(Embedding embedding) { public String add(Embedding embedding) {
@@ -289,11 +321,11 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
return new GsonInMemoryEmbeddingStoreJsonCodec(); return new GsonInMemoryEmbeddingStoreJsonCodec();
} }
public static InMemoryEmbeddingStore<TextSegment> fromJson(String json) { public static InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
return CODEC.fromJson(json); return CODEC.fromJson(json);
} }
public static InMemoryEmbeddingStore<TextSegment> fromFile(Path filePath) { public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(Path filePath) {
try { try {
String json = new String(Files.readAllBytes(filePath)); String json = new String(Files.readAllBytes(filePath));
return fromJson(json); return fromJson(json);
@@ -302,7 +334,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
} }
} }
public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) { public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(String filePath) {
return fromFile(Paths.get(filePath)); return fromFile(Paths.get(filePath));
} }
} }

View File

@@ -79,4 +79,8 @@ s2:
logging: logging:
level: level:
dev.langchain4j: DEBUG dev.langchain4j: DEBUG
dev.ai4j.openai4j: DEBUG dev.ai4j.openai4j: DEBUG
inMemoryEmbeddingStore:
persistent:
path: /tmp