mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(chat) Support configuring embeddingModel or embeddingStore at the agent level. (#1361)
This commit is contained in:
@@ -185,6 +185,8 @@
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
|
||||
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ChatModelConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private String provider;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String modelName;
|
||||
private Double temperature = 0.0d;
|
||||
private Long timeOut = 60L;
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class EmbeddingModelConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private String provider;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String modelName;
|
||||
private String modelPath;
|
||||
private String vocabularyPath;
|
||||
private Integer maxRetries;
|
||||
private Integer maxToken;
|
||||
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
public class EmbeddingStoreConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private String provider;
|
||||
private String persistPath;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private Long timeOut = 60L;
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class LLMConfig {
|
||||
|
||||
private String provider;
|
||||
|
||||
private String baseUrl;
|
||||
|
||||
private String apiKey;
|
||||
|
||||
private String modelName;
|
||||
|
||||
private Double temperature = 0.0d;
|
||||
|
||||
private Long timeOut = 60L;
|
||||
|
||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
}
|
||||
|
||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
|
||||
double temperature) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
this.temperature = temperature;
|
||||
}
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(apiKey);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ModelConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private ChatModelConfig chatModel;
|
||||
private EmbeddingModelConfig embeddingModel;
|
||||
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
|
||||
public interface ChatLanguageModelFactory {
|
||||
ChatLanguageModel create(LLMConfig llmConfig);
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class ChatLanguageModelProvider {
|
||||
private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();
|
||||
|
||||
static {
|
||||
factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
|
||||
factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
|
||||
factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provide(LLMConfig llmConfig) {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|
||||
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
|
||||
ChatLanguageModelFactory factory = factories.get(llmConfig.getProvider().toUpperCase());
|
||||
if (factory != null) {
|
||||
return factory.create(llmConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported provider: " + llmConfig.getProvider());
|
||||
}
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
public class LocalAiChatModelFactory implements ChatLanguageModelFactory {
|
||||
@Override
|
||||
public ChatLanguageModel create(LLMConfig llmConfig) {
|
||||
return LocalAiChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
public enum ModelProvider {
|
||||
OPEN_AI,
|
||||
HUGGING_FACE,
|
||||
LOCAL_AI,
|
||||
IN_PROCESS,
|
||||
OLLAMA
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
public class OllamaChatModelFactory implements ChatLanguageModelFactory {
|
||||
@Override
|
||||
public ChatLanguageModel create(LLMConfig llmConfig) {
|
||||
return OllamaChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
public class OpenAiChatModelFactory implements ChatLanguageModelFactory {
|
||||
@Override
|
||||
public ChatLanguageModel create(LLMConfig llmConfig) {
|
||||
return OpenAiChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.apiKey(llmConfig.keyDecrypt())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.deploymentName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut()));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder()
|
||||
.endpoint(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.deploymentName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
|
||||
&& embeddingModelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.AZURE, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return QwenChatModel.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature() == null ? 0L :
|
||||
chatModel.getTemperature().floatValue())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QwenEmbeddingModel.builder()
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.DASHSCOPE, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
|
||||
public interface EmbeddingStoreFactory {
|
||||
EmbeddingStore createEmbeddingStore(EmbeddingStoreConfig config);
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
public enum EmbeddingStoreType {
|
||||
IN_MEMORY,
|
||||
MILVUS,
|
||||
CHROMA
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.ALL_MINILM_L6_V2;
|
||||
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
|
||||
|
||||
@Service
|
||||
public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
String modelPath = embeddingModel.getModelPath();
|
||||
String vocabularyPath = embeddingModel.getVocabularyPath();
|
||||
if (StringUtils.isNotBlank(modelPath) && StringUtils.isNotBlank(vocabularyPath)) {
|
||||
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
|
||||
}
|
||||
String modelName = embeddingModel.getModelName();
|
||||
if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
|
||||
return new BgeSmallZhEmbeddingModel();
|
||||
}
|
||||
if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
|
||||
return new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
}
|
||||
return new BgeSmallZhEmbeddingModel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.IN_MEMORY, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return LocalAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return LocalAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModel.getBaseUrl())
|
||||
.modelName(embeddingModel.getModelName())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.LOCAL_AI, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
||||
public interface ModelFactory {
|
||||
ChatLanguageModel createChatModel(ChatModelConfig llmConfig);
|
||||
|
||||
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ModelProvider {
|
||||
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
||||
|
||||
public static void add(Provider provider, ModelFactory modelFactory) {
|
||||
factories.put(provider.name(), modelFactory);
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|
||||
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatModel(llmConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || modelConfig.getChatModel() == null
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
ChatModelConfig chatModel = modelConfig.getChatModel();
|
||||
ModelFactory modelFactory = factories.get(chatModel.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatModel(chatModel);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
|
||||
}
|
||||
|
||||
public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
|
||||
return ContextUtils.getBean(EmbeddingModel.class);
|
||||
}
|
||||
EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel();
|
||||
|
||||
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createEmbeddingModel(embeddingModel);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported EmbeddingModel provider: " + embeddingModel.getProvider());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return OllamaChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return OllamaEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OLLAMA, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return OpenAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.apiKey(chatModel.keyDecrypt())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return OpenAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModel.getBaseUrl())
|
||||
.apiKey(embeddingModel.getApiKey())
|
||||
.modelName(embeddingModel.getModelName())
|
||||
.maxRetries(embeddingModel.getMaxRetries())
|
||||
.logRequests(embeddingModel.getLogRequests())
|
||||
.logResponses(embeddingModel.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OPEN_AI, this);
|
||||
}
|
||||
}
|
||||
12
common/src/main/java/dev/langchain4j/provider/Provider.java
Normal file
12
common/src/main/java/dev/langchain4j/provider/Provider.java
Normal file
@@ -0,0 +1,12 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
public enum Provider {
|
||||
OPEN_AI,
|
||||
OLLAMA,
|
||||
LOCAL_AI,
|
||||
IN_MEMORY,
|
||||
ZHIPU,
|
||||
AZURE,
|
||||
QIANFAN,
|
||||
DASHSCOPE
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QianfanEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.QIANFAN, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return ZhipuAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.ZHIPU, this);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user