mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
[improvement](chat) Add EmbeddingPersistentTask to Persist and reload InMemoryS2EmbeddingStore (#525)
This commit is contained in:
@@ -32,4 +32,6 @@ public class EmbeddingConfig {
|
|||||||
@Value("${embedding.metric.analyzeQuery.nResult:5}")
|
@Value("${embedding.metric.analyzeQuery.nResult:5}")
|
||||||
private int metricAnalyzeQueryResultNum;
|
private int metricAnalyzeQueryResultNum;
|
||||||
|
|
||||||
|
@Value("${embeddingStore.persistent.path:/tmp}")
|
||||||
|
private String embeddingStorePersistentPath;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
|
import javax.annotation.PreDestroy;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.scheduling.annotation.Scheduled;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Component
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingPersistentTask {
|
||||||
|
|
||||||
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
|
|
||||||
|
@PreDestroy
|
||||||
|
public void onShutdown() {
|
||||||
|
embeddingStorePersistentToFile();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void embeddingStorePersistentToFile() {
|
||||||
|
if (s2EmbeddingStore instanceof InMemoryS2EmbeddingStore) {
|
||||||
|
log.info("start persistentToFile");
|
||||||
|
((InMemoryS2EmbeddingStore) s2EmbeddingStore).persistentToFile();
|
||||||
|
log.info("end persistentToFile");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
|
||||||
|
public void executeTask() {
|
||||||
|
embeddingStorePersistentToFile();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,14 +3,13 @@ package com.tencent.supersonic.common.util.embedding;
|
|||||||
import com.google.gson.Gson;
|
import com.google.gson.Gson;
|
||||||
import com.google.gson.reflect.TypeToken;
|
import com.google.gson.reflect.TypeToken;
|
||||||
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import java.lang.reflect.Type;
|
import java.lang.reflect.Type;
|
||||||
|
|
||||||
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {
|
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
|
public InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
|
||||||
Type type = new TypeToken<InMemoryEmbeddingStore<TextSegment>>() {
|
Type type = new TypeToken<InMemoryEmbeddingStore<EmbeddingQuery>>() {
|
||||||
}.getType();
|
}.getType();
|
||||||
return new Gson().fromJson(json, type);
|
return new Gson().fromJson(json, type);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package com.tencent.supersonic.common.util.embedding;
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
|
|
||||||
public interface InMemoryEmbeddingStoreJsonCodec {
|
public interface InMemoryEmbeddingStoreJsonCodec {
|
||||||
InMemoryEmbeddingStore<TextSegment> fromJson(String json);
|
|
||||||
|
InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json);
|
||||||
|
|
||||||
String toJson(InMemoryEmbeddingStore<?> store);
|
String toJson(InMemoryEmbeddingStore<?> store);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import static java.nio.file.StandardOpenOption.CREATE;
|
|||||||
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
|
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
|
||||||
import static java.util.Comparator.comparingDouble;
|
import static java.util.Comparator.comparingDouble;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
@@ -38,12 +38,44 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||||
|
|
||||||
|
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
||||||
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
|
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
|
||||||
new ConcurrentHashMap<>();
|
new ConcurrentHashMap<>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public synchronized void addCollection(String collectionName) {
|
public synchronized void addCollection(String collectionName) {
|
||||||
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
|
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null;
|
||||||
|
Path filePath = getPersistentPath(collectionName);
|
||||||
|
try {
|
||||||
|
if (Files.exists(filePath)) {
|
||||||
|
embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
|
||||||
|
embeddingStore.entries = new CopyOnWriteArrayList<>(embeddingStore.entries);
|
||||||
|
log.info("embeddingStore reload from file:{}", filePath);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("load persistentFile error, persistentFile:" + filePath, e);
|
||||||
|
}
|
||||||
|
if (Objects.isNull(embeddingStore)) {
|
||||||
|
embeddingStore = new InMemoryEmbeddingStore();
|
||||||
|
}
|
||||||
|
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Path getPersistentPath(String collectionName) {
|
||||||
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
|
String persistentFile = PERSISTENT_FILE_PRE + collectionName;
|
||||||
|
return Paths.get(embeddingConfig.getEmbeddingStorePersistentPath(), persistentFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void persistentToFile() {
|
||||||
|
for (Entry<String, InMemoryEmbeddingStore<EmbeddingQuery>> entry : collectionNameToStore.entrySet()) {
|
||||||
|
Path filePath = getPersistentPath(entry.getKey());
|
||||||
|
try {
|
||||||
|
entry.getValue().serializeToFile(filePath);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("persistentToFile error, persistentFile:" + filePath, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -179,7 +211,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();
|
private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();
|
||||||
private final List<InMemoryEmbeddingStore.Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
|
private List<InMemoryEmbeddingStore.Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String add(Embedding embedding) {
|
public String add(Embedding embedding) {
|
||||||
@@ -289,11 +321,11 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
|
public static InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
|
||||||
return CODEC.fromJson(json);
|
return CODEC.fromJson(json);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static InMemoryEmbeddingStore<TextSegment> fromFile(Path filePath) {
|
public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(Path filePath) {
|
||||||
try {
|
try {
|
||||||
String json = new String(Files.readAllBytes(filePath));
|
String json = new String(Files.readAllBytes(filePath));
|
||||||
return fromJson(json);
|
return fromJson(json);
|
||||||
@@ -302,7 +334,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) {
|
public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(String filePath) {
|
||||||
return fromFile(Paths.get(filePath));
|
return fromFile(Paths.get(filePath));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,4 +79,8 @@ s2:
|
|||||||
logging:
|
logging:
|
||||||
level:
|
level:
|
||||||
dev.langchain4j: DEBUG
|
dev.langchain4j: DEBUG
|
||||||
dev.ai4j.openai4j: DEBUG
|
dev.ai4j.openai4j: DEBUG
|
||||||
|
|
||||||
|
inMemoryEmbeddingStore:
|
||||||
|
persistent:
|
||||||
|
path: /tmp
|
||||||
Reference in New Issue
Block a user