diff --git a/.github/ISSUE_TEMPLATE/bug_request.yml b/.github/ISSUE_TEMPLATE/bug_report.yml similarity index 100% rename from .github/ISSUE_TEMPLATE/bug_request.yml rename to .github/ISSUE_TEMPLATE/bug_report.yml diff --git a/.github/ISSUE_TEMPLATE/enhancement_request.yml b/.github/ISSUE_TEMPLATE/enhancement_request.yml index 223e44f7f..482dce376 100644 --- a/.github/ISSUE_TEMPLATE/enhancement_request.yml +++ b/.github/ISSUE_TEMPLATE/enhancement_request.yml @@ -1,5 +1,5 @@ -name: SuperSonic enhancement -description: Add an enhanment for SuperSonic +name: SuperSonic enhancement request +description: Add an enhancement for SuperSonic title: "[Enhancement] " labels: enhancement diff --git a/.github/ISSUE_TEMPLATE/question_request.yml b/.github/ISSUE_TEMPLATE/question_request.yml index cb61cb28d..b02b5fcf4 100644 --- a/.github/ISSUE_TEMPLATE/question_request.yml +++ b/.github/ISSUE_TEMPLATE/question_request.yml @@ -1,5 +1,5 @@ name: SuperSonic question request -description: Ask a question of SuperSonic +description: Ask a question about SuperSonic title: "[question] " labels: question body: diff --git a/common/pom.xml b/common/pom.xml index 7b4232110..2e8fe5ff6 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -260,6 +260,11 @@ spring-boot-autoconfigure-processor + + com.google.code.gson + gson + + diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/EmbeddingStoreProperties.java b/common/src/main/java/dev/langchain4j/inmemory/spring/EmbeddingStoreProperties.java index 7e88c01f0..f59599de9 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/EmbeddingStoreProperties.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/EmbeddingStoreProperties.java @@ -7,5 +7,5 @@ import lombok.Setter; @Setter class EmbeddingStoreProperties { - private String filePath; + private String persistPath; } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java index 70e0ba36c..e4dd3f56a 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java @@ -1,8 +1,6 @@ package dev.langchain4j.inmemory.spring; -import static dev.langchain4j.inmemory.spring.Properties.PREFIX; - import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -14,6 +12,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static dev.langchain4j.inmemory.spring.Properties.PREFIX; + @Configuration @EnableConfigurationProperties(Properties.class) public class InMemoryAutoConfig { @@ -22,8 +22,8 @@ public class InMemoryAutoConfig { public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q"; @Bean - @ConditionalOnProperty(PREFIX + ".embedding-store.file-path") - EmbeddingStoreFactory milvusChatModel(Properties properties) { + @ConditionalOnProperty(PREFIX + ".embedding-store.persist-path") + EmbeddingStoreFactory inMemoryChatModel(Properties properties) { return new InMemoryEmbeddingStoreFactory(properties); } diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java index 56edefc6e..969347619 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java @@ -1,18 +1,27 @@ package dev.langchain4j.inmemory.spring; +import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.MapUtils; +import org.apache.commons.lang3.StringUtils; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArraySet; @Slf4j public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { + public static final String PERSISTENT_FILE_PRE = "InMemory."; private static Map> collectionNameToStore = new ConcurrentHashMap<>(); private Properties properties; @@ -28,9 +37,63 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { if (Objects.nonNull(embeddingStore)) { return embeddingStore; } - embeddingStore = new InMemoryEmbeddingStore(); + embeddingStore = reloadFromPersistFile(collectionName); + if (Objects.isNull(embeddingStore)) { + embeddingStore = new InMemoryEmbeddingStore(); + } collectionNameToStore.putIfAbsent(collectionName, embeddingStore); return embeddingStore; } + + private InMemoryEmbeddingStore reloadFromPersistFile(String collectionName) { + Path filePath = getPersistPath(collectionName); + if (Objects.isNull(filePath)) { + return null; + } + InMemoryEmbeddingStore embeddingStore = null; + try { + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); + if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName()) + && !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) { + embeddingStore = InMemoryEmbeddingStore.fromFile(filePath); + embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries); + log.info("embeddingStore reload from file:{}", filePath); + } + } catch (Exception e) { + log.error("load persistFile error, persistFile:" + filePath, e); + } + return embeddingStore; + } + + public synchronized void persistFile() { + if (MapUtils.isEmpty(collectionNameToStore)) { + return; + } + for (Map.Entry> entry : collectionNameToStore.entrySet()) { + Path filePath = getPersistPath(entry.getKey()); + if (Objects.isNull(filePath)) { + continue; + } + try { + Path directoryPath = filePath.getParent(); + if (!Files.exists(directoryPath)) { + Files.createDirectories(directoryPath); + } + entry.getValue().serializeToFile(filePath); + } catch (Exception e) { + log.error("persistFile error, persistFile:" + filePath, e); + } + } + } + + private Path getPersistPath(String collectionName) { + String persistFile = PERSISTENT_FILE_PRE + collectionName; + String persistPath = properties.getEmbeddingStore().getPersistPath(); + if (StringUtils.isEmpty(persistPath)) { + return null; + } + return Paths.get(persistPath, persistFile); + } + } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java b/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java new file mode 100644 index 000000000..05c1475d7 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java @@ -0,0 +1,250 @@ +package dev.langchain4j.store.embedding.inmemory; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory; +import dev.langchain4j.store.embedding.CosineSimilarity; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; +import dev.langchain4j.store.embedding.filter.Filter; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArraySet; +import java.util.stream.IntStream; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.spi.ServiceHelper.loadFactories; +import static java.nio.file.StandardOpenOption.CREATE; +import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; +import static java.util.Comparator.comparingDouble; +import static java.util.stream.Collectors.toList; + +/** + * An {@link EmbeddingStore} that stores embeddings in memory. + *

+ * Uses a brute force approach by iterating over all embeddings to find the best matches. + *

+ * This store can be persisted using the {@link #serializeToJson()} and {@link #serializeToFile(Path)} methods. + *

+ * It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link #fromFile(Path)} methods. + * + * @param The class of the object that has been embedded. + * Typically, it is {@link dev.langchain4j.data.segment.TextSegment}. + */ +public class InMemoryEmbeddingStore implements EmbeddingStore { + + public Set> entries = new CopyOnWriteArraySet<>(); + + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + add(id, embedding); + return id; + } + + @Override + public void add(String id, Embedding embedding) { + add(id, embedding, null); + } + + @Override + public String add(Embedding embedding, Embedded embedded) { + String id = randomUUID(); + add(id, embedding, embedded); + return id; + } + + public void add(String id, Embedding embedding, Embedded embedded) { + entries.add(new Entry<>(id, embedding, embedded)); + } + + private List add(List> newEntries) { + + entries.addAll(newEntries); + + return newEntries.stream() + .map(entry -> entry.id) + .collect(toList()); + } + + @Override + public List addAll(List embeddings) { + + List> newEntries = embeddings.stream() + .map(embedding -> new Entry(randomUUID(), embedding)) + .collect(toList()); + + return add(newEntries); + } + + @Override + public List addAll(List embeddings, List embedded) { + if (embeddings.size() != embedded.size()) { + throw new IllegalArgumentException("The list of embeddings and embedded must have the same size"); + } + + List> newEntries = IntStream.range(0, embeddings.size()) + .mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i))) + .collect(toList()); + + return add(newEntries); + } + + @Override + public void removeAll(Collection ids) { + ensureNotEmpty(ids, "ids"); + + entries.removeIf(entry -> ids.contains(entry.id)); + } + + @Override + public void removeAll(Filter filter) { + ensureNotNull(filter, "filter"); + + entries.removeIf(entry -> { + if (entry.embedded instanceof TextSegment) { + return filter.test(((TextSegment) entry.embedded).metadata()); + } else if (entry.embedded == null) { + return false; + } else { + throw new UnsupportedOperationException("Not supported yet."); + } + }); + } + + @Override + public void removeAll() { + entries.clear(); + } + + @Override + public EmbeddingSearchResult search(EmbeddingSearchRequest embeddingSearchRequest) { + + Comparator> comparator = comparingDouble(EmbeddingMatch::score); + PriorityQueue> matches = new PriorityQueue<>(comparator); + + Filter filter = embeddingSearchRequest.filter(); + + for (Entry entry : entries) { + + if (filter != null && entry.embedded instanceof TextSegment) { + Metadata metadata = ((TextSegment) entry.embedded).metadata(); + if (!filter.test(metadata)) { + continue; + } + } + + double cosineSimilarity = CosineSimilarity.between(entry.embedding, + embeddingSearchRequest.queryEmbedding()); + double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity); + if (score >= embeddingSearchRequest.minScore()) { + matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded)); + if (matches.size() > embeddingSearchRequest.maxResults()) { + matches.poll(); + } + } + } + + List> result = new ArrayList<>(matches); + result.sort(comparator); + Collections.reverse(result); + + return new EmbeddingSearchResult<>(result); + } + + public String serializeToJson() { + return loadCodec().toJson(this); + } + + public void serializeToFile(Path filePath) { + try { + String json = serializeToJson(); + Files.write(filePath, json.getBytes(), CREATE, TRUNCATE_EXISTING); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void serializeToFile(String filePath) { + serializeToFile(Paths.get(filePath)); + } + + public static InMemoryEmbeddingStore fromJson(String json) { + return loadCodec().fromJson(json); + } + + public static InMemoryEmbeddingStore fromFile(Path filePath) { + try { + String json = new String(Files.readAllBytes(filePath)); + return fromJson(json); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static InMemoryEmbeddingStore fromFile(String filePath) { + return fromFile(Paths.get(filePath)); + } + + private static class Entry { + + String id; + Embedding embedding; + Embedded embedded; + + Entry(String id, Embedding embedding) { + this(id, embedding, null); + } + + Entry(String id, Embedding embedding, Embedded embedded) { + this.id = ensureNotBlank(id, "id"); + this.embedding = ensureNotNull(embedding, "embedding"); + this.embedded = embedded; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Entry that = (Entry) o; + return Objects.equals(this.id, that.id) + && Objects.equals(this.embedding, that.embedding) + && Objects.equals(this.embedded, that.embedded); + } + + @Override + public int hashCode() { + return Objects.hash(id, embedding, embedded); + } + } + + private static InMemoryEmbeddingStoreJsonCodec loadCodec() { + for (InMemoryEmbeddingStoreJsonCodecFactory factory : + loadFactories(InMemoryEmbeddingStoreJsonCodecFactory.class)) { + return factory.create(); + } + return new GsonInMemoryEmbeddingStoreJsonCodec(); + } +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java index 3f9175391..6404b2eef 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java @@ -42,7 +42,6 @@ public class MetaEmbeddingListener implements ApplicationListener { return; } sleep(); - embeddingService.addCollection(embeddingConfig.getMetaCollectionName()); if (event.getEventType().equals(EventType.ADD)) { embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments); } else if (event.getEventType().equals(EventType.DELETE)) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java index ebfc44465..dc250bad3 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java @@ -5,6 +5,8 @@ import com.tencent.supersonic.common.pojo.DataItem; import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.headless.server.web.service.DimensionService; import com.tencent.supersonic.headless.server.web.service.MetricService; +import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; +import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.TextSegmentConvert; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -28,28 +30,33 @@ public class EmbeddingTask { @Autowired private DimensionService dimensionService; + @Autowired + private EmbeddingStoreFactory embeddingStoreFactory; + @PreDestroy public void onShutdown() { - // embeddingStorePersistentToFile(); + embeddingStorePersistFile(); } - // private void embeddingStorePersistentToFile() { - // if (embeddingService instanceof InMemoryEmbeddingService) { - // log.info("start persistentToFile"); - // ((InMemoryEmbeddingService) embeddingService).persistentToFile(); - // log.info("end persistentToFile"); - // } - // } + private void embeddingStorePersistFile() { + if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) { + log.info("start persistFile"); + InMemoryEmbeddingStoreFactory inMemoryFactory = + (InMemoryEmbeddingStoreFactory) embeddingStoreFactory; + inMemoryFactory.persistFile(); + log.info("end persistFile"); + } + } - @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}") - public void executeTask() { - // embeddingStorePersistentToFile(); + @Scheduled(cron = "${s2.inMemoryEmbeddingStore.persist.cron:0 0 * * * ?}") + public void executePersistFileTask() { + embeddingStorePersistFile(); } /*** * reload meta embedding */ - @Scheduled(cron = "${reload.meta.embedding.corn:0 0 */2 * * ?}") + @Scheduled(cron = "${s2.reload.meta.embedding.corn:0 0 */2 * * ?}") public void reloadMetaEmbedding() { log.info("reload.meta.embedding start"); try { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java index 50f29d1ba..db2a87b2b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/rest/KnowledgeController.java @@ -141,6 +141,12 @@ public class KnowledgeController { return true; } + @GetMapping("/embedding/persistFile") + public Object executePersistFileTask() { + embeddingTask.executePersistFileTask(); + return true; + } + /** * queryDictValue-返回字典的数据 * @@ -161,8 +167,8 @@ public class KnowledgeController { */ @PostMapping("/dict/file") public String queryDictFilePath(@RequestBody @Valid DictValueReq dictValueReq, - HttpServletRequest request, - HttpServletResponse response) { + HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return taskService.queryDictFilePath(dictValueReq, user); } diff --git a/launchers/standalone/src/main/resources/langchain4j-config.yaml b/launchers/standalone/src/main/resources/langchain4j-config.yaml index 39086c564..8ff4a1e4d 100644 --- a/launchers/standalone/src/main/resources/langchain4j-config.yaml +++ b/launchers/standalone/src/main/resources/langchain4j-config.yaml @@ -13,4 +13,4 @@ langchain4j: embedding-model: model-name: bge-small-zh embedding-store: - file-path: /tmp \ No newline at end of file + persist-path: /tmp \ No newline at end of file diff --git a/launchers/standalone/src/test/resources/langchain4j-config.yaml b/launchers/standalone/src/test/resources/langchain4j-config.yaml index 39086c564..8ff4a1e4d 100644 --- a/launchers/standalone/src/test/resources/langchain4j-config.yaml +++ b/launchers/standalone/src/test/resources/langchain4j-config.yaml @@ -13,4 +13,4 @@ langchain4j: embedding-model: model-name: bge-small-zh embedding-store: - file-path: /tmp \ No newline at end of file + persist-path: /tmp \ No newline at end of file diff --git a/pom.xml b/pom.xml index 79bf0c622..2084ffd38 100644 --- a/pom.xml +++ b/pom.xml @@ -215,6 +215,11 @@ spring-boot-autoconfigure-processor ${spring.version} + + com.google.code.gson + gson + ${gson.version} +