(improvement)(chat) Support embedding store configuration. (#1363)

This commit is contained in:
lexluo09
2024-07-07 00:30:19 +08:00
committed by GitHub
parent 3f460429e6
commit 4d7bfe07aa
37 changed files with 185 additions and 119 deletions

View File

@@ -1,15 +1,23 @@
package dev.langchain4j.chroma.spring;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import java.time.Duration;
@Slf4j
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private Properties properties;
public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig));
}
public ChromaEmbeddingStoreFactory(Properties properties) {
this.properties = properties;
}
@@ -23,4 +31,13 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
.timeout(storeProperties.getTimeout())
.build();
}
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
properties.setEmbeddingStore(embeddingStore);
return properties;
}
}

View File

@@ -7,7 +7,7 @@ import java.time.Duration;
@Getter
@Setter
class EmbeddingStoreProperties {
public class EmbeddingStoreProperties {
private String baseUrl;
private Duration timeout;

View File

@@ -1,6 +1,7 @@
package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
@@ -9,6 +10,7 @@ import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -23,11 +25,22 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private Properties properties;
public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig));
}
public InMemoryEmbeddingStoreFactory(Properties properties) {
this.properties = properties;
}
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
properties.setEmbeddingStore(embeddingStore);
return properties;
}
@Override
public synchronized EmbeddingStore createEmbeddingStore(String collectionName) {
InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName);

View File

@@ -1,17 +1,32 @@
package dev.langchain4j.milvus.spring;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
import org.springframework.beans.BeanUtils;
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private final Properties properties;
public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig));
}
public MilvusEmbeddingStoreFactory(Properties properties) {
this.properties = properties;
}
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setUri(storeConfig.getBaseUrl());
properties.setEmbeddingStore(embeddingStore);
return properties;
}
@Override
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();

View File

@@ -13,15 +13,16 @@ import java.time.Duration;
@Service
public class AzureModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "AZURE";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(chatModel.getBaseUrl())
.apiKey(chatModel.getApiKey())
.deploymentName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut()));
.endpoint(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()));
return builder.build();
}
@@ -39,6 +40,6 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.AZURE, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -11,14 +11,16 @@ import org.springframework.stereotype.Service;
@Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QwenChatModel.builder()
.baseUrl(chatModel.getBaseUrl())
.apiKey(chatModel.getApiKey())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature() == null ? 0L :
chatModel.getTemperature().floatValue())
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature() == null ? 0L :
modelConfig.getTemperature().floatValue())
.build();
}
@@ -32,6 +34,6 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.DASHSCOPE, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -16,9 +16,11 @@ import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
@Service
public class InMemoryModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "IN_MEMORY";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
return null;
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
throw new UnsupportedOperationException("Not supported yet.");
}
@Override
@@ -40,6 +42,6 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.IN_MEMORY, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -13,14 +13,16 @@ import java.time.Duration;
@Service
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "LOCAL_AI";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel
.builder()
.baseUrl(chatModel.getBaseUrl())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.build();
}
@@ -34,6 +36,6 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.LOCAL_AI, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -6,7 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
public interface ModelFactory {
ChatLanguageModel createChatModel(ChatModelConfig llmConfig);
ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
}

View File

@@ -15,24 +15,11 @@ import java.util.Objects;
public class ModelProvider {
private static final Map<String, ModelFactory> factories = new HashMap<>();
public static void add(Provider provider, ModelFactory modelFactory) {
factories.put(provider.name(), modelFactory);
public static void add(String provider, ModelFactory modelFactory) {
factories.put(provider, modelFactory);
}
public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) {
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
return ContextUtils.getBean(ChatLanguageModel.class);
}
ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase());
if (modelFactory != null) {
return modelFactory.createChatModel(llmConfig);
}
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
}
public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
public static ChatLanguageModel getChatModel(ModelConfig modelConfig) {
if (modelConfig == null || modelConfig.getChatModel() == null
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
@@ -47,7 +34,7 @@ public class ModelProvider {
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
}
public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) {
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {

View File

@@ -13,14 +13,16 @@ import java.time.Duration;
@Service
public class OllamaModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OLLAMA";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OllamaChatModel
.builder()
.baseUrl(chatModel.getBaseUrl())
.modelName(chatModel.getModelName())
.temperature(chatModel.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.build();
}
@@ -37,6 +39,6 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.OLLAMA, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -13,15 +13,17 @@ import java.time.Duration;
@Service
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OPEN_AI";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OpenAiChatModel
.builder()
.baseUrl(chatModel.getBaseUrl())
.modelName(chatModel.getModelName())
.apiKey(chatModel.keyDecrypt())
.temperature(chatModel.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.build();
}
@@ -39,6 +41,6 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.OPEN_AI, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -1,12 +0,0 @@
package dev.langchain4j.provider;
public enum Provider {
OPEN_AI,
OLLAMA,
LOCAL_AI,
IN_MEMORY,
ZHIPU,
AZURE,
QIANFAN,
DASHSCOPE
}

View File

@@ -10,8 +10,11 @@ import org.springframework.stereotype.Service;
@Service
public class QianfanModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "QIANFAN";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null;
}
@@ -29,6 +32,6 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.QIANFAN, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -10,8 +10,10 @@ import org.springframework.stereotype.Service;
@Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null;
}
@@ -29,6 +31,6 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
@Override
public void afterPropertiesSet() {
ModelProvider.add(Provider.ZHIPU, this);
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -0,0 +1,26 @@
package dev.langchain4j.store.embedding;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
import org.apache.commons.lang3.StringUtils;
public class EmbeddingStoreFactoryProvider {
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) {
if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) {
return ContextUtils.getBean(EmbeddingStoreFactory.class);
}
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) {
return new ChromaEmbeddingStoreFactory(storeConfig);
}
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) {
return new MilvusEmbeddingStoreFactory(storeConfig);
}
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) {
return new InMemoryEmbeddingStoreFactory(storeConfig);
}
throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider());
}
}

View File

@@ -1,4 +1,4 @@
package dev.langchain4j.provider;
package dev.langchain4j.store.embedding;
public enum EmbeddingStoreType {
IN_MEMORY,