mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-21 22:34:28 +08:00
(improvement)(chat) Support embedding store configuration. (#1363)
This commit is contained in:
@@ -13,15 +13,16 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "AZURE";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.deploymentName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut()));
|
||||
.endpoint(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.deploymentName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@@ -39,6 +40,6 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.AZURE, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,14 +11,16 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "DASHSCOPE";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QwenChatModel.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature() == null ? 0L :
|
||||
chatModel.getTemperature().floatValue())
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature() == null ? 0L :
|
||||
modelConfig.getTemperature().floatValue())
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -32,6 +34,6 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.DASHSCOPE, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
public enum EmbeddingStoreType {
|
||||
IN_MEMORY,
|
||||
MILVUS,
|
||||
CHROMA
|
||||
}
|
||||
@@ -16,9 +16,11 @@ import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
|
||||
|
||||
@Service
|
||||
public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "IN_MEMORY";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
throw new UnsupportedOperationException("Not supported yet.");
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -40,6 +42,6 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.IN_MEMORY, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,14 +13,16 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "LOCAL_AI";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return LocalAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -34,6 +36,6 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.LOCAL_AI, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
||||
public interface ModelFactory {
|
||||
ChatLanguageModel createChatModel(ChatModelConfig llmConfig);
|
||||
ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
|
||||
|
||||
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
|
||||
}
|
||||
|
||||
@@ -15,24 +15,11 @@ import java.util.Objects;
|
||||
public class ModelProvider {
|
||||
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
||||
|
||||
public static void add(Provider provider, ModelFactory modelFactory) {
|
||||
factories.put(provider.name(), modelFactory);
|
||||
public static void add(String provider, ModelFactory modelFactory) {
|
||||
factories.put(provider, modelFactory);
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|
||||
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatModel(llmConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
|
||||
public static ChatLanguageModel getChatModel(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || modelConfig.getChatModel() == null
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
|
||||
@@ -47,7 +34,7 @@ public class ModelProvider {
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
|
||||
}
|
||||
|
||||
public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) {
|
||||
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
|
||||
|
||||
@@ -13,14 +13,16 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "OLLAMA";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OllamaChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -37,6 +39,6 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OLLAMA, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,15 +13,17 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "OPEN_AI";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OpenAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.apiKey(chatModel.keyDecrypt())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.apiKey(modelConfig.keyDecrypt())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -39,6 +41,6 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OPEN_AI, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
public enum Provider {
|
||||
OPEN_AI,
|
||||
OLLAMA,
|
||||
LOCAL_AI,
|
||||
IN_MEMORY,
|
||||
ZHIPU,
|
||||
AZURE,
|
||||
QIANFAN,
|
||||
DASHSCOPE
|
||||
}
|
||||
@@ -10,8 +10,11 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
public static final String PROVIDER = "QIANFAN";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -29,6 +32,6 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.QIANFAN, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,10 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "ZHIPU";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -29,6 +31,6 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.ZHIPU, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user