From d5c78d87e70b1154cf12b0ae7062c6db2ee759f4 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sun, 14 Jul 2024 21:50:12 +0800 Subject: [PATCH] (improvement)(chat) Support integrating embeddingStore into system settings. (#1405) --- .../config/EmbeddingModelParameterConfig.java | 2 +- .../common/pojo/EmbeddingStoreConfig.java | 2 ++ .../service/impl/EmbeddingServiceImpl.java | 9 +++--- .../chroma/spring/ChromaAutoConfig.java | 2 +- .../spring/ChromaEmbeddingStoreFactory.java | 13 +++----- .../spring/EmbeddingStoreProperties.java | 2 ++ .../embedding/BaseEmbeddingStoreFactory.java | 2 +- .../EmbeddingStoreFactoryProvider.java | 32 +++++++++++++------ .../server/task/MetaEmbeddingTask.java | 5 ++- 9 files changed, 41 insertions(+), 28 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java index c403ecb03..cd176b031 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -15,7 +15,7 @@ import org.springframework.stereotype.Service; import java.util.List; -@Service("EmbeddingModelConfig") +@Service("EmbeddingModelParameterConfig") @Slf4j public class EmbeddingModelParameterConfig extends ParameterConfig { diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java index e10b19e99..c92af21e6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.common.pojo; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import java.io.Serializable; @@ -11,6 +12,7 @@ import java.io.Serializable; @Builder @AllArgsConstructor @NoArgsConstructor +@EqualsAndHashCode public class EmbeddingStoreConfig implements Serializable { private static final long serialVersionUID = 1L; private String provider; diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index c9b203a55..919b641f7 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -14,6 +14,7 @@ import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreFactory; +import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; @@ -40,10 +41,6 @@ import java.util.stream.Collectors; @Service @Slf4j public class EmbeddingServiceImpl implements EmbeddingService { - - @Autowired - private EmbeddingStoreFactory embeddingStoreFactory; - @Autowired private EmbeddingModelParameterConfig embeddingModelParameterConfig; @@ -54,6 +51,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { @Override public void addQuery(String collectionName, List queries) { + EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory(); EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); for (TextSegment query : queries) { @@ -102,6 +100,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { @Override public void deleteQuery(String collectionName, List queries) { //Not supported yet in Milvus and Chroma + EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory(); EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); try { if (embeddingStore instanceof InMemoryEmbeddingStore) { @@ -122,7 +121,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { @Override public List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { List results = new ArrayList<>(); - + EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory(); EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); List queryTextsList = retrieveQuery.getQueryTextsList(); Map filterCondition = retrieveQuery.getFilterCondition(); diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java index 276966088..fc68dd575 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java @@ -16,6 +16,6 @@ public class ChromaAutoConfig { @Bean @ConditionalOnProperty(PREFIX + ".embedding-store.base-url") EmbeddingStoreFactory chromaChatModel(Properties properties) { - return new ChromaEmbeddingStoreFactory(properties); + return new ChromaEmbeddingStoreFactory(properties.getEmbeddingStore()); } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java index d38270f0d..dca01b3af 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java @@ -12,19 +12,18 @@ import java.time.Duration; @Slf4j public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { - private Properties properties; + private EmbeddingStoreProperties storeProperties; public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { this(createPropertiesFromConfig(storeConfig)); } - public ChromaEmbeddingStoreFactory(Properties properties) { - this.properties = properties; + public ChromaEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) { + this.storeProperties = storeProperties; } @Override public EmbeddingStore createEmbeddingStore(String collectionName) { - EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore(); return ChromaEmbeddingStore.builder() .baseUrl(storeProperties.getBaseUrl()) .collectionName(collectionName) @@ -32,12 +31,10 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { .build(); } - private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { - Properties properties = new Properties(); + private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); BeanUtils.copyProperties(storeConfig, embeddingStore); embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut())); - properties.setEmbeddingStore(embeddingStore); - return properties; + return embeddingStore; } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java b/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java index b30bdb252..72e7f5d9e 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java @@ -2,11 +2,13 @@ package dev.langchain4j.chroma.spring; import lombok.Getter; import lombok.Setter; +import lombok.ToString; import java.time.Duration; @Getter @Setter +@ToString public class EmbeddingStoreProperties { private String baseUrl; diff --git a/common/src/main/java/dev/langchain4j/store/embedding/BaseEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/store/embedding/BaseEmbeddingStoreFactory.java index 2f895b1b9..2c92f1a24 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/BaseEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/BaseEmbeddingStoreFactory.java @@ -6,7 +6,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory { - protected static final Map> collectionNameToStore = new ConcurrentHashMap<>(); + protected final Map> collectionNameToStore = new ConcurrentHashMap<>(); public EmbeddingStore create(String collectionName) { return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore); diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java index d3dc5876c..dcac744e0 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java @@ -1,5 +1,6 @@ package dev.langchain4j.store.embedding; +import com.tencent.supersonic.common.config.EmbeddingStoreParameterConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory; @@ -7,20 +8,33 @@ import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory; import org.apache.commons.lang3.StringUtils; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + public class EmbeddingStoreFactoryProvider { - public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) { - if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) { + protected static final Map factoryMap = new ConcurrentHashMap<>(); + + public static EmbeddingStoreFactory getFactory() { + EmbeddingStoreParameterConfig parameterConfig = ContextUtils.getBean(EmbeddingStoreParameterConfig.class); + return getFactory(parameterConfig.convert()); + } + + public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig embeddingStoreConfig) { + if (embeddingStoreConfig == null || StringUtils.isBlank(embeddingStoreConfig.getProvider())) { return ContextUtils.getBean(EmbeddingStoreFactory.class); } - if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) { - return new ChromaEmbeddingStoreFactory(storeConfig); + if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) { + return factoryMap.computeIfAbsent(embeddingStoreConfig, + storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig)); } - if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) { - return new MilvusEmbeddingStoreFactory(storeConfig); + if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) { + return factoryMap.computeIfAbsent(embeddingStoreConfig, + storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig)); } - if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) { - return new InMemoryEmbeddingStoreFactory(storeConfig); + if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) { + return factoryMap.computeIfAbsent(embeddingStoreConfig, + storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig)); } - throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider()); + throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: " + embeddingStoreConfig.getProvider()); } } \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java index 3a1081b5e..db41671ec 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.server.web.service.DimensionService; import com.tencent.supersonic.headless.server.web.service.MetricService; import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStoreFactory; +import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider; import dev.langchain4j.store.embedding.TextSegmentConvert; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -35,15 +36,13 @@ public class MetaEmbeddingTask implements CommandLineRunner { @Autowired private DimensionService dimensionService; - @Autowired - private EmbeddingStoreFactory embeddingStoreFactory; - @PreDestroy public void onShutdown() { embeddingStorePersistFile(); } private void embeddingStorePersistFile() { + EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory(); if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) { long startTime = System.currentTimeMillis(); InMemoryEmbeddingStoreFactory inMemoryFactory =