diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java index cd176b031..cecfedb7e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -20,7 +20,7 @@ import java.util.List; public class EmbeddingModelParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_MODEL_PROVIDER = - new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, + new Parameter("s2.embedding.model.provider", "", "接口协议", "", "string", "向量模型配置", Lists.newArrayList(InMemoryModelFactory.PROVIDER, diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java index 5884f28b6..2861061c0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -3,7 +3,6 @@ package com.tencent.supersonic.common.config; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.Parameter; -import dev.langchain4j.provider.InMemoryModelFactory; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -13,7 +12,7 @@ import java.util.List; @Slf4j public class EmbeddingStoreParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_STORE_PROVIDER = - new Parameter("s2.embedding.store.provider", InMemoryModelFactory.PROVIDER, + new Parameter("s2.embedding.store.provider", "", "向量库类型", "", "string", "向量库配置"); diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 919b641f7..b0e89b9b4 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -2,8 +2,6 @@ package com.tencent.supersonic.common.service.impl; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig; -import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.service.EmbeddingService; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; @@ -26,7 +24,6 @@ import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.MapUtils; import org.apache.commons.collections4.CollectionUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -41,8 +38,6 @@ import java.util.stream.Collectors; @Service @Slf4j public class EmbeddingServiceImpl implements EmbeddingService { - @Autowired - private EmbeddingModelParameterConfig embeddingModelParameterConfig; private Cache cache = CacheBuilder.newBuilder() .maximumSize(10000) @@ -57,7 +52,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { for (TextSegment query : queries) { String question = query.text(); try { - EmbeddingModel embeddingModel = getEmbeddingModel(); + EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); Embedding embedding = embeddingModel.embed(question).content(); boolean existSegment = existSegment(embeddingStore, query, embedding); if (existSegment) { @@ -126,7 +121,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { List queryTextsList = retrieveQuery.getQueryTextsList(); Map filterCondition = retrieveQuery.getFilterCondition(); for (String queryText : queryTextsList) { - EmbeddingModel embeddingModel = getEmbeddingModel(); + EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); Embedding embeddedText = embeddingModel.embed(queryText).content(); Filter filter = createCombinedFilter(filterCondition); EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() @@ -174,9 +169,4 @@ public class EmbeddingServiceImpl implements EmbeddingService { } return result; } - - private EmbeddingModel getEmbeddingModel() { - EmbeddingModelConfig embeddingModelConfig = embeddingModelParameterConfig.convert(); - return ModelProvider.getEmbeddingModel(embeddingModelConfig); - } } diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java index e4dd3f56a..ab79ea5d7 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java @@ -24,7 +24,7 @@ public class InMemoryAutoConfig { @Bean @ConditionalOnProperty(PREFIX + ".embedding-store.persist-path") EmbeddingStoreFactory inMemoryChatModel(Properties properties) { - return new InMemoryEmbeddingStoreFactory(properties); + return new InMemoryEmbeddingStoreFactory(properties.getEmbeddingStore()); } @Bean diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java index 91922074d..20aa64838 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java @@ -23,22 +23,20 @@ import java.util.concurrent.CopyOnWriteArraySet; public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public static final String PERSISTENT_FILE_PRE = "InMemory."; - private Properties properties; + private EmbeddingStoreProperties embeddingStore; public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { this(createPropertiesFromConfig(storeConfig)); } - public InMemoryEmbeddingStoreFactory(Properties properties) { - this.properties = properties; + public InMemoryEmbeddingStoreFactory(EmbeddingStoreProperties embeddingStore) { + this.embeddingStore = embeddingStore; } - private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { - Properties properties = new Properties(); + private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); BeanUtils.copyProperties(storeConfig, embeddingStore); - properties.setEmbeddingStore(embeddingStore); - return properties; + return embeddingStore; } @Override @@ -97,7 +95,7 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { private Path getPersistPath(String collectionName) { String persistFile = PERSISTENT_FILE_PRE + collectionName; - String persistPath = properties.getEmbeddingStore().getPersistPath(); + String persistPath = embeddingStore.getPersistPath(); if (StringUtils.isEmpty(persistPath)) { return null; } diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusAutoConfig.java b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusAutoConfig.java index eeb5e9626..47ff35ba6 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusAutoConfig.java @@ -15,6 +15,6 @@ public class MilvusAutoConfig { @Bean @ConditionalOnProperty(PREFIX + ".embedding-store.uri") EmbeddingStoreFactory milvusChatModel(Properties properties) { - return new MilvusEmbeddingStoreFactory(properties); + return new MilvusEmbeddingStoreFactory(properties.getEmbeddingStore()); } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java index 074a30f70..2ff3733fa 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java @@ -8,28 +8,25 @@ import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; import org.springframework.beans.BeanUtils; public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { - private final Properties properties; + private final EmbeddingStoreProperties storeProperties; public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { this(createPropertiesFromConfig(storeConfig)); } - public MilvusEmbeddingStoreFactory(Properties properties) { - this.properties = properties; + public MilvusEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) { + this.storeProperties = storeProperties; } - private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { - Properties properties = new Properties(); + private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); BeanUtils.copyProperties(storeConfig, embeddingStore); embeddingStore.setUri(storeConfig.getBaseUrl()); - properties.setEmbeddingStore(embeddingStore); - return properties; + return embeddingStore; } @Override public EmbeddingStore createEmbeddingStore(String collectionName) { - EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore(); return MilvusEmbeddingStore.builder() .host(storeProperties.getHost()) .port(storeProperties.getPort()) diff --git a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java index ab3c8f142..a2e7aeed5 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java @@ -1,5 +1,6 @@ package dev.langchain4j.provider; +import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.util.ContextUtils; @@ -31,6 +32,13 @@ public class ModelProvider { throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider()); } + public static EmbeddingModel getEmbeddingModel() { + EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean( + EmbeddingModelParameterConfig.class); + EmbeddingModelConfig embeddingModelConfig = parameterConfig.convert(); + return getEmbeddingModel(embeddingModelConfig); + } + public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) { if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) { return ContextUtils.getBean(EmbeddingModel.class);