[Improvement][headless] Only expose OPEN_AI/OLLAMA/DIFY chat model providers. (#2237)

This commit is contained in:
beat4ocean
2025-05-05 15:51:26 +08:00
committed by GitHub
parent acffc03c79
commit c4992501bd
26 changed files with 25 additions and 980 deletions

View File

@@ -1,23 +0,0 @@
package dev.langchain4j.dashscope.spring;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
@Getter
@Setter
class ChatModelProperties {
String baseUrl;
String apiKey;
String modelName;
Double topP;
Integer topK;
Boolean enableSearch;
Integer seed;
Float repetitionPenalty;
Float temperature;
List<String> stops;
Integer maxTokens;
}

View File

@@ -1,84 +0,0 @@
package dev.langchain4j.dashscope.spring;
import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
import dev.langchain4j.model.dashscope.QwenLanguageModel;
import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
import dev.langchain4j.model.dashscope.QwenStreamingLanguageModel;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.dashscope.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class DashscopeAutoConfig {
@Bean
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
QwenChatModel qwenChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.enableSearch(chatModelProperties.getEnableSearch())
.seed(chatModelProperties.getSeed())
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
.temperature(chatModelProperties.getTemperature())
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
.build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.enableSearch(chatModelProperties.getEnableSearch())
.seed(chatModelProperties.getSeed())
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
.temperature(chatModelProperties.getTemperature())
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
.build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
QwenLanguageModel qwenLanguageModel(Properties properties) {
ChatModelProperties languageModel = properties.getLanguageModel();
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
.repetitionPenalty(languageModel.getRepetitionPenalty())
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
.repetitionPenalty(languageModel.getRepetitionPenalty())
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
.modelName(embeddingModelProperties.getModelName()).build();
}
}

View File

@@ -1,12 +0,0 @@
package dev.langchain4j.dashscope.spring;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
class EmbeddingModelProperties {
private String apiKey;
private String modelName;
}

View File

@@ -1,29 +0,0 @@
package dev.langchain4j.dashscope.spring;
import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
@Getter
@Setter
@ConfigurationProperties(prefix = Properties.PREFIX)
public class Properties {
static final String PREFIX = "langchain4j.dashscope";
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties languageModel;
@NestedConfigurationProperty
ChatModelProperties streamingLanguageModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -66,7 +66,6 @@ import static java.util.Collections.singletonList;
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
public static final String ZHIPU = "bigmodel";
private final OpenAiClient client;
private final String baseUrl;
private final String modelName;
@@ -192,9 +191,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
.responseFormat(responseFormat).seed(seed).user(user)
.parallelToolCalls(parallelToolCalls);
if (!(baseUrl.contains(ZHIPU))) {
requestBuilder.temperature(temperature);
}
requestBuilder.temperature(temperature);
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
requestBuilder.tools(toTools(toolSpecifications, strictTools));

View File

@@ -1,16 +0,0 @@
package dev.langchain4j.model.zhipu;
public enum ChatCompletionModel {
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
private final String value;
ChatCompletionModel(String value) {
this.value = value;
}
@Override
public String toString() {
return this.value;
}
}

View File

@@ -1,100 +0,0 @@
package dev.langchain4j.model.zhipu;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse;
import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory;
import lombok.Builder;
import java.util.List;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.aiMessageFrom;
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom;
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools;
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages;
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom;
import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.Collections.singletonList;
/**
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
* glm-4. You can find description of parameters
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
*/
public class ZhipuAiChatModel implements ChatLanguageModel {
private final String baseUrl;
private final Double temperature;
private final Double topP;
private final String model;
private final Integer maxRetries;
private final Integer maxToken;
private final ZhipuAiClient client;
@Builder
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
Boolean logResponses) {
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512);
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false)).build();
}
public static ZhipuAiChatModelBuilder builder() {
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
ZhipuAiChatModelBuilderFactory.class)) {
return factories.get();
}
return new ZhipuAiChatModelBuilder();
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, (ToolSpecification) null);
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
}
ChatCompletionResponse response =
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages,
toolSpecification != null ? singletonList(toolSpecification) : null);
}
public static class ZhipuAiChatModelBuilder {
public ZhipuAiChatModelBuilder() {}
}
}

View File

@@ -1,51 +0,0 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import java.time.Duration;
@Service
public class AzureModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "AZURE";
public static final String DEFAULT_BASE_URL = "https://your-resource-name.openai.azure.com/";
public static final String DEFAULT_MODEL_NAME = "gpt-35-turbo";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(Duration.ofSeconds(
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
.logRequestsAndResponses(
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
return builder.build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
AzureOpenAiEmbeddingModel.Builder builder =
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.deploymentName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
&& embeddingModelConfig.getLogResponses());
return builder.build();
}
@Override
public void afterPropertiesSet() {
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -1,39 +0,0 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
public static final String DEFAULT_MODEL_NAME = QwenModelName.QWEN_PLUS;
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-v2";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature() == null ? 0L
: modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
.modelName(embeddingModelConfig.getModelName()).build();
}
@Override
public void afterPropertiesSet() {
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.common.util.AESEncryptionUtil;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dify.DifyAiChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -27,8 +27,9 @@ public class DifyModelFactory implements ModelFactory, InitializingBean {
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses()).build();

View File

@@ -1,47 +0,0 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanChatModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@Service
public class QianfanModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "QIANFAN";
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
public static final String DEFAULT_MODEL_NAME = "Llama-2-70b-chat";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "Embedding-V1";
public static final String DEFAULT_ENDPOINT = "llama_2_70b";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.secretKey(embeddingModelConfig.getSecretKey())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override
public void afterPropertiesSet() {
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -1,45 +0,0 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ChatCompletionModel;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import static java.time.Duration.ofSeconds;
@Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU";
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override
public void afterPropertiesSet() {
ModelProvider.add(PROVIDER, this);
}
}

View File

@@ -1,21 +0,0 @@
package dev.langchain4j.qianfan.spring;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
class ChatModelProperties {
private String baseUrl;
private String apiKey;
private String secretKey;
private Double temperature;
private Integer maxRetries;
private Double topP;
private String modelName;
private String endpoint;
private String responseFormat;
private Double penaltyScore;
private Boolean logRequests;
private Boolean logResponses;
}

View File

@@ -1,18 +0,0 @@
package dev.langchain4j.qianfan.spring;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
class EmbeddingModelProperties {
private String baseUrl;
private String apiKey;
private String secretKey;
private Integer maxRetries;
private String modelName;
private String endpoint;
private String user;
private Boolean logRequests;
private Boolean logResponses;
}

View File

@@ -1,21 +0,0 @@
package dev.langchain4j.qianfan.spring;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
class LanguageModelProperties {
private String baseUrl;
private String apiKey;
private String secretKey;
private Double temperature;
private Integer maxRetries;
private Integer topK;
private Double topP;
private String modelName;
private String endpoint;
private Double penaltyScore;
private Boolean logRequests;
private Boolean logResponses;
}

View File

@@ -1,29 +0,0 @@
package dev.langchain4j.qianfan.spring;
import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
@Getter
@Setter
@ConfigurationProperties(prefix = Properties.PREFIX)
public class Properties {
static final String PREFIX = "langchain4j.qianfan";
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
LanguageModelProperties languageModel;
@NestedConfigurationProperty
LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -1,102 +0,0 @@
package dev.langchain4j.qianfan.spring;
import dev.langchain4j.model.qianfan.QianfanChatModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanLanguageModel;
import dev.langchain4j.model.qianfan.QianfanStreamingChatModel;
import dev.langchain4j.model.qianfan.QianfanStreamingLanguageModel;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.qianfan.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class QianfanAutoConfig {
@Bean
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
QianfanChatModel qianfanChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.secretKey(chatModelProperties.getSecretKey())
.endpoint(chatModelProperties.getEndpoint())
.penaltyScore(chatModelProperties.getPenaltyScore())
.modelName(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.responseFormat(chatModelProperties.getResponseFormat())
.maxRetries(chatModelProperties.getMaxRetries())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint())
.penaltyScore(chatModelProperties.getPenaltyScore())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.secretKey(chatModelProperties.getSecretKey())
.modelName(chatModelProperties.getModelName())
.responseFormat(chatModelProperties.getResponseFormat())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
QianfanLanguageModel qianfanLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint())
.penaltyScore(languageModelProperties.getPenaltyScore())
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.apiKey(languageModelProperties.getApiKey())
.secretKey(languageModelProperties.getSecretKey())
.modelName(languageModelProperties.getModelName())
.temperature(languageModelProperties.getTemperature())
.maxRetries(languageModelProperties.getMaxRetries())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
QianfanStreamingLanguageModel qianfanStreamingLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
return QianfanStreamingLanguageModel.builder()
.endpoint(languageModelProperties.getEndpoint())
.penaltyScore(languageModelProperties.getPenaltyScore())
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.apiKey(languageModelProperties.getApiKey())
.secretKey(languageModelProperties.getSecretKey())
.modelName(languageModelProperties.getModelName())
.temperature(languageModelProperties.getTemperature())
.maxRetries(languageModelProperties.getMaxRetries())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
.endpoint(embeddingModelProperties.getEndpoint())
.apiKey(embeddingModelProperties.getApiKey())
.secretKey(embeddingModelProperties.getSecretKey())
.modelName(embeddingModelProperties.getModelName())
.user(embeddingModelProperties.getUser())
.maxRetries(embeddingModelProperties.getMaxRetries())
.logRequests(embeddingModelProperties.getLogRequests())
.logResponses(embeddingModelProperties.getLogResponses()).build();
}
}

View File

@@ -1,19 +0,0 @@
package dev.langchain4j.zhipu.spring;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
class ChatModelProperties {
String baseUrl;
String apiKey;
Double temperature;
Double topP;
String modelName;
Integer maxRetries;
Integer maxToken;
Boolean logRequests;
Boolean logResponses;
}

View File

@@ -1,16 +0,0 @@
package dev.langchain4j.zhipu.spring;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
class EmbeddingModelProperties {
String baseUrl;
String apiKey;
String model;
Integer maxRetries;
Boolean logRequests;
Boolean logResponses;
}

View File

@@ -1,23 +0,0 @@
package dev.langchain4j.zhipu.spring;
import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
@Getter
@Setter
@ConfigurationProperties(prefix = Properties.PREFIX)
public class Properties {
static final String PREFIX = "langchain4j.zhipu";
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -1,53 +0,0 @@
package dev.langchain4j.zhipu.spring;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiStreamingChatModel;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.zhipu.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class ZhipuAutoConfig {
@Bean
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
ZhipuAiChatModel zhipuAiChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return ZhipuAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
.maxToken(chatModelProperties.getMaxToken())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
ZhipuAiStreamingChatModel zhipuStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return ZhipuAiStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP()).maxToken(chatModelProperties.getMaxToken())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
ZhipuAiEmbeddingModel zhipuEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
.apiKey(embeddingModelProperties.getApiKey())
.model(embeddingModelProperties.getModel())
.maxRetries(embeddingModelProperties.getMaxRetries())
.logRequests(embeddingModelProperties.getLogRequests())
.logResponses(embeddingModelProperties.getLogResponses()).build();
}
}