mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
281 lines
14 KiB
Java
281 lines
14 KiB
Java
package dev.langchain4j;
|
|
|
|
import static dev.langchain4j.ModelProvider.OPEN_AI;
|
|
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
|
|
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
|
|
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
|
|
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
|
|
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
|
|
import dev.langchain4j.model.language.LanguageModel;
|
|
import dev.langchain4j.model.localai.LocalAiChatModel;
|
|
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
|
|
import dev.langchain4j.model.localai.LocalAiLanguageModel;
|
|
import dev.langchain4j.model.moderation.ModerationModel;
|
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
|
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
|
import dev.langchain4j.model.openai.OpenAiLanguageModel;
|
|
import dev.langchain4j.model.openai.OpenAiModerationModel;
|
|
import java.util.Arrays;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
import org.springframework.context.annotation.Bean;
|
|
import org.springframework.context.annotation.Configuration;
|
|
import org.springframework.context.annotation.Lazy;
|
|
import org.springframework.context.annotation.Primary;
|
|
|
|
@Configuration
|
|
@EnableConfigurationProperties(LangChain4jProperties.class)
|
|
public class S2LangChain4jAutoConfiguration {
|
|
|
|
@Autowired
|
|
private LangChain4jProperties properties;
|
|
|
|
@Bean
|
|
@Lazy
|
|
@ConditionalOnMissingBean
|
|
ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
|
|
if (properties.getChatModel() == null) {
|
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
|
|
+ "langchain4j.chat-model.provider = openai\n"
|
|
+ "langchain4j.chat-model.openai.api-key = sk-...\n");
|
|
}
|
|
|
|
switch (properties.getChatModel().getProvider()) {
|
|
|
|
case OPEN_AI:
|
|
OpenAi openAi = properties.getChatModel().getOpenAi();
|
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.openai.api-key' property");
|
|
}
|
|
return OpenAiChatModel.builder()
|
|
.baseUrl(openAi.getBaseUrl())
|
|
.apiKey(openAi.getApiKey())
|
|
.modelName(openAi.getModelName())
|
|
.temperature(openAi.getTemperature())
|
|
.topP(openAi.getTopP())
|
|
.maxTokens(openAi.getMaxTokens())
|
|
.presencePenalty(openAi.getPresencePenalty())
|
|
.frequencyPenalty(openAi.getFrequencyPenalty())
|
|
.timeout(openAi.getTimeout())
|
|
.maxRetries(openAi.getMaxRetries())
|
|
.logRequests(openAi.getLogRequests())
|
|
.logResponses(openAi.getLogResponses())
|
|
.build();
|
|
|
|
case HUGGING_FACE:
|
|
HuggingFace huggingFace = properties.getChatModel().getHuggingFace();
|
|
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.chat-model.huggingface.access-token' property");
|
|
}
|
|
return HuggingFaceChatModel.builder()
|
|
.accessToken(huggingFace.getAccessToken())
|
|
.modelId(huggingFace.getModelId())
|
|
.timeout(huggingFace.getTimeout())
|
|
.temperature(huggingFace.getTemperature())
|
|
.maxNewTokens(huggingFace.getMaxNewTokens())
|
|
.returnFullText(huggingFace.getReturnFullText())
|
|
.waitForModel(huggingFace.getWaitForModel())
|
|
.build();
|
|
|
|
case LOCAL_AI:
|
|
LocalAi localAi = properties.getChatModel().getLocalAi();
|
|
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
|
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.localai.base-url' property");
|
|
}
|
|
if (isNullOrBlank(localAi.getModelName())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.chat-model.localai.model-name' property");
|
|
}
|
|
return LocalAiChatModel.builder()
|
|
.baseUrl(localAi.getBaseUrl())
|
|
.modelName(localAi.getModelName())
|
|
.temperature(localAi.getTemperature())
|
|
.topP(localAi.getTopP())
|
|
.maxTokens(localAi.getMaxTokens())
|
|
.timeout(localAi.getTimeout())
|
|
.maxRetries(localAi.getMaxRetries())
|
|
.logRequests(localAi.getLogRequests())
|
|
.logResponses(localAi.getLogResponses())
|
|
.build();
|
|
|
|
default:
|
|
throw illegalConfiguration("Unsupported chat model provider: %s",
|
|
properties.getChatModel().getProvider());
|
|
}
|
|
}
|
|
|
|
@Bean
|
|
@Lazy
|
|
@ConditionalOnMissingBean
|
|
LanguageModel languageModel(LangChain4jProperties properties) {
|
|
if (properties.getLanguageModel() == null) {
|
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
|
|
+ "langchain4j.language-model.provider = openai\n"
|
|
+ "langchain4j.language-model.openai.api-key = sk-...\n");
|
|
}
|
|
|
|
switch (properties.getLanguageModel().getProvider()) {
|
|
|
|
case OPEN_AI:
|
|
OpenAi openAi = properties.getLanguageModel().getOpenAi();
|
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.language-model.openai.api-key' property");
|
|
}
|
|
return OpenAiLanguageModel.builder()
|
|
.apiKey(openAi.getApiKey())
|
|
.baseUrl(openAi.getBaseUrl())
|
|
.modelName(openAi.getModelName())
|
|
.temperature(openAi.getTemperature())
|
|
.timeout(openAi.getTimeout())
|
|
.maxRetries(openAi.getMaxRetries())
|
|
.logRequests(openAi.getLogRequests())
|
|
.logResponses(openAi.getLogResponses())
|
|
.build();
|
|
|
|
case HUGGING_FACE:
|
|
HuggingFace huggingFace = properties.getLanguageModel().getHuggingFace();
|
|
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.language-model.huggingface.access-token' property");
|
|
}
|
|
return HuggingFaceLanguageModel.builder()
|
|
.accessToken(huggingFace.getAccessToken())
|
|
.modelId(huggingFace.getModelId())
|
|
.timeout(huggingFace.getTimeout())
|
|
.temperature(huggingFace.getTemperature())
|
|
.maxNewTokens(huggingFace.getMaxNewTokens())
|
|
.returnFullText(huggingFace.getReturnFullText())
|
|
.waitForModel(huggingFace.getWaitForModel())
|
|
.build();
|
|
|
|
case LOCAL_AI:
|
|
LocalAi localAi = properties.getLanguageModel().getLocalAi();
|
|
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.language-model.localai.base-url' property");
|
|
}
|
|
if (isNullOrBlank(localAi.getModelName())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.language-model.localai.model-name' property");
|
|
}
|
|
return LocalAiLanguageModel.builder()
|
|
.baseUrl(localAi.getBaseUrl())
|
|
.modelName(localAi.getModelName())
|
|
.temperature(localAi.getTemperature())
|
|
.topP(localAi.getTopP())
|
|
.maxTokens(localAi.getMaxTokens())
|
|
.timeout(localAi.getTimeout())
|
|
.maxRetries(localAi.getMaxRetries())
|
|
.logRequests(localAi.getLogRequests())
|
|
.logResponses(localAi.getLogResponses())
|
|
.build();
|
|
|
|
default:
|
|
throw illegalConfiguration("Unsupported language model provider: %s",
|
|
properties.getLanguageModel().getProvider());
|
|
}
|
|
}
|
|
|
|
@Bean
|
|
@Lazy
|
|
@ConditionalOnMissingBean
|
|
@Primary
|
|
EmbeddingModel embeddingModel(LangChain4jProperties properties) {
|
|
|
|
if (properties.getEmbeddingModel() == null || !Arrays.stream(ModelProvider.values())
|
|
.anyMatch(provider -> provider.equals(properties.getEmbeddingModel().getProvider()))) {
|
|
return new AllMiniLmL6V2EmbeddingModel();
|
|
}
|
|
|
|
switch (properties.getEmbeddingModel().getProvider()) {
|
|
|
|
case OPEN_AI:
|
|
OpenAi openAi = properties.getEmbeddingModel().getOpenAi();
|
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.embedding-model.openai.api-key' property");
|
|
}
|
|
return OpenAiEmbeddingModel.builder()
|
|
.apiKey(openAi.getApiKey())
|
|
.baseUrl(openAi.getBaseUrl())
|
|
.modelName(openAi.getModelName())
|
|
.timeout(openAi.getTimeout())
|
|
.maxRetries(openAi.getMaxRetries())
|
|
.logRequests(openAi.getLogRequests())
|
|
.logResponses(openAi.getLogResponses())
|
|
.build();
|
|
|
|
case HUGGING_FACE:
|
|
HuggingFace huggingFace = properties.getEmbeddingModel().getHuggingFace();
|
|
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.embedding-model.huggingface.access-token' property");
|
|
}
|
|
return HuggingFaceEmbeddingModel.builder()
|
|
.accessToken(huggingFace.getAccessToken())
|
|
.modelId(huggingFace.getModelId())
|
|
.waitForModel(huggingFace.getWaitForModel())
|
|
.timeout(huggingFace.getTimeout())
|
|
.build();
|
|
|
|
case LOCAL_AI:
|
|
LocalAi localAi = properties.getEmbeddingModel().getLocalAi();
|
|
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.embedding-model.localai.base-url' property");
|
|
}
|
|
if (isNullOrBlank(localAi.getModelName())) {
|
|
throw illegalConfiguration(
|
|
"\n\nPlease define 'langchain4j.embedding-model.localai.model-name' property");
|
|
}
|
|
return LocalAiEmbeddingModel.builder()
|
|
.baseUrl(localAi.getBaseUrl())
|
|
.modelName(localAi.getModelName())
|
|
.timeout(localAi.getTimeout())
|
|
.maxRetries(localAi.getMaxRetries())
|
|
.logRequests(localAi.getLogRequests())
|
|
.logResponses(localAi.getLogResponses())
|
|
.build();
|
|
default:
|
|
return new AllMiniLmL6V2EmbeddingModel();
|
|
}
|
|
}
|
|
|
|
@Bean
|
|
@Lazy
|
|
@ConditionalOnMissingBean
|
|
ModerationModel moderationModel(LangChain4jProperties properties) {
|
|
if (properties.getModerationModel() == null) {
|
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
|
|
+ "langchain4j.moderation-model.provider = openai\n"
|
|
+ "langchain4j.moderation-model.openai.api-key = sk-...\n");
|
|
}
|
|
|
|
if (properties.getModerationModel().getProvider() != OPEN_AI) {
|
|
throw illegalConfiguration("Unsupported moderation model provider: %s",
|
|
properties.getModerationModel().getProvider());
|
|
}
|
|
|
|
OpenAi openAi = properties.getModerationModel().getOpenAi();
|
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model.openai.api-key' property");
|
|
}
|
|
|
|
return OpenAiModerationModel.builder()
|
|
.apiKey(openAi.getApiKey())
|
|
.modelName(openAi.getModelName())
|
|
.timeout(openAi.getTimeout())
|
|
.maxRetries(openAi.getMaxRetries())
|
|
.logRequests(openAi.getLogRequests())
|
|
.logResponses(openAi.getLogResponses())
|
|
.build();
|
|
}
|
|
|
|
} |