> entry : collectionNameToStore.entrySet()) {
+ Path filePath = getPersistPath(entry.getKey());
+ if (Objects.isNull(filePath)) {
+ continue;
+ }
+ try {
+ Path directoryPath = filePath.getParent();
+ if (!Files.exists(directoryPath)) {
+ Files.createDirectories(directoryPath);
+ }
+ entry.getValue().serializeToFile(filePath);
+ } catch (Exception e) {
+ log.error("persistFile error, persistFile:" + filePath, e);
+ }
+ }
+ }
+
+ private Path getPersistPath(String collectionName) {
+ String persistFile = PERSISTENT_FILE_PRE + collectionName;
+ String persistPath = properties.getEmbeddingStore().getPersistPath();
+ if (StringUtils.isEmpty(persistPath)) {
+ return null;
+ }
+ return Paths.get(persistPath, persistFile);
+ }
+
}
\ No newline at end of file
diff --git a/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java b/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java
new file mode 100644
index 000000000..05c1475d7
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java
@@ -0,0 +1,250 @@
+package dev.langchain4j.store.embedding.inmemory;
+
+import dev.langchain4j.data.document.Metadata;
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
+import dev.langchain4j.store.embedding.CosineSimilarity;
+import dev.langchain4j.store.embedding.EmbeddingMatch;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.EmbeddingSearchResult;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.RelevanceScore;
+import dev.langchain4j.store.embedding.filter.Filter;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.Set;
+import java.util.concurrent.CopyOnWriteArraySet;
+import java.util.stream.IntStream;
+
+import static dev.langchain4j.internal.Utils.randomUUID;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
+import static dev.langchain4j.spi.ServiceHelper.loadFactories;
+import static java.nio.file.StandardOpenOption.CREATE;
+import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
+import static java.util.Comparator.comparingDouble;
+import static java.util.stream.Collectors.toList;
+
+/**
+ * An {@link EmbeddingStore} that stores embeddings in memory.
+ *
+ * Uses a brute force approach by iterating over all embeddings to find the best matches.
+ *
+ * This store can be persisted using the {@link #serializeToJson()} and {@link #serializeToFile(Path)} methods.
+ *
+ * It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link #fromFile(Path)} methods.
+ *
+ * @param The class of the object that has been embedded.
+ * Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
+ */
+public class InMemoryEmbeddingStore implements EmbeddingStore {
+
+ public Set> entries = new CopyOnWriteArraySet<>();
+
+ @Override
+ public String add(Embedding embedding) {
+ String id = randomUUID();
+ add(id, embedding);
+ return id;
+ }
+
+ @Override
+ public void add(String id, Embedding embedding) {
+ add(id, embedding, null);
+ }
+
+ @Override
+ public String add(Embedding embedding, Embedded embedded) {
+ String id = randomUUID();
+ add(id, embedding, embedded);
+ return id;
+ }
+
+ public void add(String id, Embedding embedding, Embedded embedded) {
+ entries.add(new Entry<>(id, embedding, embedded));
+ }
+
+ private List add(List> newEntries) {
+
+ entries.addAll(newEntries);
+
+ return newEntries.stream()
+ .map(entry -> entry.id)
+ .collect(toList());
+ }
+
+ @Override
+ public List addAll(List embeddings) {
+
+ List> newEntries = embeddings.stream()
+ .map(embedding -> new Entry(randomUUID(), embedding))
+ .collect(toList());
+
+ return add(newEntries);
+ }
+
+ @Override
+ public List addAll(List embeddings, List embedded) {
+ if (embeddings.size() != embedded.size()) {
+ throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
+ }
+
+ List> newEntries = IntStream.range(0, embeddings.size())
+ .mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
+ .collect(toList());
+
+ return add(newEntries);
+ }
+
+ @Override
+ public void removeAll(Collection ids) {
+ ensureNotEmpty(ids, "ids");
+
+ entries.removeIf(entry -> ids.contains(entry.id));
+ }
+
+ @Override
+ public void removeAll(Filter filter) {
+ ensureNotNull(filter, "filter");
+
+ entries.removeIf(entry -> {
+ if (entry.embedded instanceof TextSegment) {
+ return filter.test(((TextSegment) entry.embedded).metadata());
+ } else if (entry.embedded == null) {
+ return false;
+ } else {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
+ });
+ }
+
+ @Override
+ public void removeAll() {
+ entries.clear();
+ }
+
+ @Override
+ public EmbeddingSearchResult search(EmbeddingSearchRequest embeddingSearchRequest) {
+
+ Comparator> comparator = comparingDouble(EmbeddingMatch::score);
+ PriorityQueue> matches = new PriorityQueue<>(comparator);
+
+ Filter filter = embeddingSearchRequest.filter();
+
+ for (Entry entry : entries) {
+
+ if (filter != null && entry.embedded instanceof TextSegment) {
+ Metadata metadata = ((TextSegment) entry.embedded).metadata();
+ if (!filter.test(metadata)) {
+ continue;
+ }
+ }
+
+ double cosineSimilarity = CosineSimilarity.between(entry.embedding,
+ embeddingSearchRequest.queryEmbedding());
+ double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
+ if (score >= embeddingSearchRequest.minScore()) {
+ matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
+ if (matches.size() > embeddingSearchRequest.maxResults()) {
+ matches.poll();
+ }
+ }
+ }
+
+ List> result = new ArrayList<>(matches);
+ result.sort(comparator);
+ Collections.reverse(result);
+
+ return new EmbeddingSearchResult<>(result);
+ }
+
+ public String serializeToJson() {
+ return loadCodec().toJson(this);
+ }
+
+ public void serializeToFile(Path filePath) {
+ try {
+ String json = serializeToJson();
+ Files.write(filePath, json.getBytes(), CREATE, TRUNCATE_EXISTING);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public void serializeToFile(String filePath) {
+ serializeToFile(Paths.get(filePath));
+ }
+
+ public static InMemoryEmbeddingStore fromJson(String json) {
+ return loadCodec().fromJson(json);
+ }
+
+ public static InMemoryEmbeddingStore fromFile(Path filePath) {
+ try {
+ String json = new String(Files.readAllBytes(filePath));
+ return fromJson(json);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static InMemoryEmbeddingStore fromFile(String filePath) {
+ return fromFile(Paths.get(filePath));
+ }
+
+ private static class Entry {
+
+ String id;
+ Embedding embedding;
+ Embedded embedded;
+
+ Entry(String id, Embedding embedding) {
+ this(id, embedding, null);
+ }
+
+ Entry(String id, Embedding embedding, Embedded embedded) {
+ this.id = ensureNotBlank(id, "id");
+ this.embedding = ensureNotNull(embedding, "embedding");
+ this.embedded = embedded;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Entry> that = (Entry>) o;
+ return Objects.equals(this.id, that.id)
+ && Objects.equals(this.embedding, that.embedding)
+ && Objects.equals(this.embedded, that.embedded);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(id, embedding, embedded);
+ }
+ }
+
+ private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
+ for (InMemoryEmbeddingStoreJsonCodecFactory factory :
+ loadFactories(InMemoryEmbeddingStoreJsonCodecFactory.class)) {
+ return factory.create();
+ }
+ return new GsonInMemoryEmbeddingStoreJsonCodec();
+ }
+}
diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java
index 3f9175391..6404b2eef 100644
--- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java
+++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java
@@ -42,7 +42,6 @@ public class MetaEmbeddingListener implements ApplicationListener {
return;
}
sleep();
- embeddingService.addCollection(embeddingConfig.getMetaCollectionName());
if (event.getEventType().equals(EventType.ADD)) {
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments);
} else if (event.getEventType().equals(EventType.DELETE)) {
diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java
index ebfc44465..dc250bad3 100644
--- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java
+++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java
@@ -5,6 +5,8 @@ import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.headless.server.web.service.DimensionService;
import com.tencent.supersonic.headless.server.web.service.MetricService;
+import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
+import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
@@ -28,28 +30,33 @@ public class EmbeddingTask {
@Autowired
private DimensionService dimensionService;
+ @Autowired
+ private EmbeddingStoreFactory embeddingStoreFactory;
+
@PreDestroy
public void onShutdown() {
- // embeddingStorePersistentToFile();
+ embeddingStorePersistFile();
}
- // private void embeddingStorePersistentToFile() {
- // if (embeddingService instanceof InMemoryEmbeddingService) {
- // log.info("start persistentToFile");
- // ((InMemoryEmbeddingService) embeddingService).persistentToFile();
- // log.info("end persistentToFile");
- // }
- // }
+ private void embeddingStorePersistFile() {
+ if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) {
+ log.info("start persistFile");
+ InMemoryEmbeddingStoreFactory inMemoryFactory =
+ (InMemoryEmbeddingStoreFactory) embeddingStoreFactory;
+ inMemoryFactory.persistFile();
+ log.info("end persistFile");
+ }
+ }
- @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
- public void executeTask() {
- // embeddingStorePersistentToFile();
+ @Scheduled(cron = "${s2.inMemoryEmbeddingStore.persist.cron:0 0 * * * ?}")
+ public void executePersistFileTask() {
+ embeddingStorePersistFile();
}
/***
* reload meta embedding
*/
- @Scheduled(cron = "${reload.meta.embedding.corn:0 0 */2 * * ?}")
+ @Scheduled(cron = "${s2.reload.meta.embedding.corn:0 0 */2 * * ?}")
public void reloadMetaEmbedding() {
log.info("reload.meta.embedding start");
try {
diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java
index 50f29d1ba..db2a87b2b 100644
--- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java
+++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java
@@ -141,6 +141,12 @@ public class KnowledgeController {
return true;
}
+ @GetMapping("/embedding/persistFile")
+ public Object executePersistFileTask() {
+ embeddingTask.executePersistFileTask();
+ return true;
+ }
+
/**
* queryDictValue-返回字典的数据
*
@@ -161,8 +167,8 @@ public class KnowledgeController {
*/
@PostMapping("/dict/file")
public String queryDictFilePath(@RequestBody @Valid DictValueReq dictValueReq,
- HttpServletRequest request,
- HttpServletResponse response) {
+ HttpServletRequest request,
+ HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return taskService.queryDictFilePath(dictValueReq, user);
}
diff --git a/launchers/standalone/src/main/resources/langchain4j-config.yaml b/launchers/standalone/src/main/resources/langchain4j-config.yaml
index 39086c564..8ff4a1e4d 100644
--- a/launchers/standalone/src/main/resources/langchain4j-config.yaml
+++ b/launchers/standalone/src/main/resources/langchain4j-config.yaml
@@ -13,4 +13,4 @@ langchain4j:
embedding-model:
model-name: bge-small-zh
embedding-store:
- file-path: /tmp
\ No newline at end of file
+ persist-path: /tmp
\ No newline at end of file
diff --git a/launchers/standalone/src/test/resources/langchain4j-config.yaml b/launchers/standalone/src/test/resources/langchain4j-config.yaml
index 39086c564..8ff4a1e4d 100644
--- a/launchers/standalone/src/test/resources/langchain4j-config.yaml
+++ b/launchers/standalone/src/test/resources/langchain4j-config.yaml
@@ -13,4 +13,4 @@ langchain4j:
embedding-model:
model-name: bge-small-zh
embedding-store:
- file-path: /tmp
\ No newline at end of file
+ persist-path: /tmp
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 79bf0c622..2084ffd38 100644
--- a/pom.xml
+++ b/pom.xml
@@ -215,6 +215,11 @@
spring-boot-autoconfigure-processor
${spring.version}
+
+ com.google.code.gson
+ gson
+ ${gson.version}
+