From e9dfb30ccf64924f0b5406e294d47a75dc4df1d2 Mon Sep 17 00:00:00 2001
From: lexluo09 <39718951+lexluo09@users.noreply.github.com>
Date: Tue, 25 Jun 2024 21:42:02 +0800
Subject: [PATCH] (improvement)(common) Fix the issue with querying large
models using the GLM interface. (#1226)
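
OpenAiChatModel now omits the `temperature` parameter whenever the base
URL contains "bigmodel" (ZhipuAI's GLM endpoint), and a dedicated
ZhipuAiChatModel is added that builds GLM chat completion requests
directly, also without a temperature parameter.

For illustration only, a GLM-backed model could be configured like this
(the endpoint URL and key lookup below are example values, not part of
this patch):

    OpenAiChatModel model = OpenAiChatModel.builder()
            .baseUrl("https://open.bigmodel.cn/api/paas/v4/") // contains "bigmodel", so temperature is skipped
            .apiKey(System.getenv("GLM_API_KEY"))
            .modelName("glm-4")
            .build();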
---
.../model/openai/OpenAiChatModel.java | 264 ++++++++++++++++++
.../model/zhipu/ZhipuAiChatModel.java | 114 ++++++++
2 files changed, 378 insertions(+)
create mode 100644 common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java
create mode 100644 common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java
diff --git a/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java
new file mode 100644
index 000000000..70c76e472
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java
@@ -0,0 +1,264 @@
+package dev.langchain4j.model.openai;
+
+import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.OpenAiHttpException;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.Tokenizer;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.TokenCountEstimator;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
+import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
+import dev.langchain4j.model.listener.ModelListener;
+import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
+import dev.langchain4j.model.output.Response;
+import lombok.Builder;
+import lombok.extern.slf4j.Slf4j;
+
+import java.net.Proxy;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static dev.langchain4j.internal.RetryUtils.withRetry;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
+import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
+import static dev.langchain4j.spi.ServiceHelper.loadFactories;
+import static java.time.Duration.ofSeconds;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.singletonList;
+
+/**
+ * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
+ * You can find a description of the parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
+ */
+@Slf4j
+public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
+
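+ // Marker substring identifying ZhipuAI's GLM ("bigmodel") endpoint in the base URL; see generate() below.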
+ public static final String ZHIPU = "bigmodel";
+ private final OpenAiClient client;
+ private final String baseUrl;
+ private final String modelName;
+ private final Double temperature;
+ private final Double topP;
+ private final List<String> stop;
+ private final Integer maxTokens;
+ private final Double presencePenalty;
+ private final Double frequencyPenalty;
+ private final Map<String, Integer> logitBias;
+ private final String responseFormat;
+ private final Integer seed;
+ private final String user;
+ private final Integer maxRetries;
+ private final Tokenizer tokenizer;
+
+ private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
+
+ @Builder
+ public OpenAiChatModel(String baseUrl,
+ String apiKey,
+ String organizationId,
+ String modelName,
+ Double temperature,
+ Double topP,
+ List<String> stop,
+ Integer maxTokens,
+ Double presencePenalty,
+ Double frequencyPenalty,
+ Map<String, Integer> logitBias,
+ String responseFormat,
+ Integer seed,
+ String user,
+ Duration timeout,
+ Integer maxRetries,
+ Proxy proxy,
+ Boolean logRequests,
+ Boolean logResponses,
+ Tokenizer tokenizer,
+ Map<String, String> customHeaders,
+ List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
+
+ baseUrl = getOrDefault(baseUrl, OPENAI_URL);
+ if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
+ baseUrl = OPENAI_DEMO_URL;
+ }
+ this.baseUrl = baseUrl;
+
+ timeout = getOrDefault(timeout, ofSeconds(60));
+
+ this.client = OpenAiClient.builder()
+ .openAiApiKey(apiKey)
+ .baseUrl(baseUrl)
+ .organizationId(organizationId)
+ .callTimeout(timeout)
+ .connectTimeout(timeout)
+ .readTimeout(timeout)
+ .writeTimeout(timeout)
+ .proxy(proxy)
+ .logRequests(logRequests)
+ .logResponses(logResponses)
+ .userAgent(DEFAULT_USER_AGENT)
+ .customHeaders(customHeaders)
+ .build();
+ this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
+ this.temperature = getOrDefault(temperature, 0.7);
+ this.topP = topP;
+ this.stop = stop;
+ this.maxTokens = maxTokens;
+ this.presencePenalty = presencePenalty;
+ this.frequencyPenalty = frequencyPenalty;
+ this.logitBias = logitBias;
+ this.responseFormat = responseFormat;
+ this.seed = seed;
+ this.user = user;
+ this.maxRetries = getOrDefault(maxRetries, 3);
+ this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
+ this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
+ }
+
+ public String modelName() {
+ return modelName;
+ }
+
+ @Override
+ public Response<AiMessage> generate(List<ChatMessage> messages) {
+ return generate(messages, null, null);
+ }
+
+ @Override
+ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
+ return generate(messages, toolSpecifications, null);
+ }
+
+ @Override
+ public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
+ return generate(messages, singletonList(toolSpecification), toolSpecification);
+ }
+
+ private Response<AiMessage> generate(List<ChatMessage> messages,
+ List<ToolSpecification> toolSpecifications,
+ ToolSpecification toolThatMustBeExecuted
+ ) {
+ ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
+ .model(modelName)
+ .messages(toOpenAiMessages(messages))
+ .topP(topP)
+ .stop(stop)
+ .maxTokens(maxTokens)
+ .presencePenalty(presencePenalty)
+ .frequencyPenalty(frequencyPenalty)
+ .logitBias(logitBias)
+ .responseFormat(responseFormat)
+ .seed(seed)
+ .user(user);
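+ // The GLM fix this patch is about: temperature is only forwarded for non-ZhipuAI base URLs.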
+ if (!(baseUrl.contains(ZHIPU))) {
+ requestBuilder.temperature(temperature);
+ }
+
+ if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
+ requestBuilder.tools(toTools(toolSpecifications));
+ }
+ if (toolThatMustBeExecuted != null) {
+ requestBuilder.toolChoice(toolThatMustBeExecuted.name());
+ }
+
+ ChatCompletionRequest request = requestBuilder.build();
+
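+ // Notify listeners of the outgoing request; listener failures are logged and never propagated.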
+ ChatLanguageModelRequest modelListenerRequest =
+ createModelListenerRequest(request, messages, toolSpecifications);
+ listeners.forEach(listener -> {
+ try {
+ listener.onRequest(modelListenerRequest);
+ } catch (Exception e) {
+ log.warn("Exception while calling model listener", e);
+ }
+ });
+
+ try {
+ ChatCompletionResponse chatCompletionResponse =
+ withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
+
+ Response<AiMessage> response = Response.from(
+ aiMessageFrom(chatCompletionResponse),
+ tokenUsageFrom(chatCompletionResponse.usage()),
+ finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())
+ );
+
+ ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
+ chatCompletionResponse.id(),
+ chatCompletionResponse.model(),
+ response
+ );
+ listeners.forEach(listener -> {
+ try {
+ listener.onResponse(modelListenerResponse, modelListenerRequest);
+ } catch (Exception e) {
+ log.warn("Exception while calling model listener", e);
+ }
+ });
+
+ return response;
+ } catch (RuntimeException e) {
+
+ Throwable error;
+ if (e.getCause() instanceof OpenAiHttpException) {
+ error = e.getCause();
+ } else {
+ error = e;
+ }
+
+ listeners.forEach(listener -> {
+ try {
+ listener.onError(error, null, modelListenerRequest);
+ } catch (Exception e2) {
+ log.warn("Exception while calling model listener", e2);
+ }
+ });
+ throw e;
+ }
+ }
+
+ @Override
+ public int estimateTokenCount(List<ChatMessage> messages) {
+ return tokenizer.estimateTokenCountInMessages(messages);
+ }
+
+ public static OpenAiChatModel withApiKey(String apiKey) {
+ return builder().apiKey(apiKey).build();
+ }
+
+ public static OpenAiChatModelBuilder builder() {
+ for (OpenAiChatModelBuilderFactory factory : loadFactories(OpenAiChatModelBuilderFactory.class)) {
+ return factory.get();
+ }
+ return new OpenAiChatModelBuilder();
+ }
+
+ public static class OpenAiChatModelBuilder {
+
+ public OpenAiChatModelBuilder() {
+ // This is public so it can be extended
+ // By default with Lombok it becomes package private
+ }
+
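+ // Lombok's @Builder generates the remaining builder fields and methods (including the modelName field assigned below) into this class.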
+ public OpenAiChatModelBuilder modelName(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+ }
+}
\ No newline at end of file
diff --git a/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java b/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java
new file mode 100644
index 000000000..fbdb39061
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java
@@ -0,0 +1,114 @@
+package dev.langchain4j.model.zhipu;
+
+import dev.ai4j.openai4j.chat.ChatCompletionModel;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
+import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse;
+import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory;
+import lombok.Builder;
+
+import java.util.List;
+
+import static dev.langchain4j.internal.RetryUtils.withRetry;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.Utils.isNullOrEmpty;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.aiMessageFrom;
+import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom;
+import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools;
+import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages;
+import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom;
+import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO;
+import static dev.langchain4j.spi.ServiceHelper.loadFactories;
+import static java.util.Collections.singletonList;
+
+/**
+ * Represents a ZhipuAi language model with a chat completion interface, such as glm-3-turbo and glm-4.
+ * You can find a description of the parameters <a href="https://open.bigmodel.cn/dev/api">here</a>.
+ */
+public class ZhipuAiChatModel implements ChatLanguageModel {
+
+ private final String baseUrl;
+ private final Double temperature;
+ private final Double topP;
+ private final String model;
+ private final Integer maxRetries;
+ private final Integer maxToken;
+ private final ZhipuAiClient client;
+
+ @Builder
+ public ZhipuAiChatModel(
+ String baseUrl,
+ String apiKey,
+ Double temperature,
+ Double topP,
+ String model,
+ Integer maxRetries,
+ Integer maxToken,
+ Boolean logRequests,
+ Boolean logResponses
+ ) {
+ this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
+ this.temperature = getOrDefault(temperature, 0.7);
+ this.topP = topP;
+ this.model = getOrDefault(model, ChatCompletionModel.GPT_4.toString());
+ this.maxRetries = getOrDefault(maxRetries, 3);
+ this.maxToken = getOrDefault(maxToken, 512);
+ this.client = ZhipuAiClient.builder()
+ .baseUrl(this.baseUrl)
+ .apiKey(apiKey)
+ .logRequests(getOrDefault(logRequests, false))
+ .logResponses(getOrDefault(logResponses, false))
+ .build();
+ }
+
+ public static ZhipuAiChatModelBuilder builder() {
+ for (ZhipuAiChatModelBuilderFactory factory : loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
+ return factory.get();
+ }
+ return new ZhipuAiChatModelBuilder();
+ }
+
+ @Override
+ public Response<AiMessage> generate(List<ChatMessage> messages) {
+ return generate(messages, (ToolSpecification) null);
+ }
+
+ @Override
+ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
+ ensureNotEmpty(messages, "messages");
+
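+ // Build a non-streaming GLM request; temperature is not sent, in line with the bigmodel special case in OpenAiChatModel.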
+ ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
+ .model(this.model)
+ .maxTokens(maxToken)
+ .stream(false)
+ .topP(topP)
+ .toolChoice(AUTO)
+ .messages(toZhipuAiMessages(messages));
+
+ if (!isNullOrEmpty(toolSpecifications)) {
+ requestBuilder.tools(toTools(toolSpecifications));
+ }
+
+ ChatCompletionResponse response = withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
+ return Response.from(
+ aiMessageFrom(response),
+ tokenUsageFrom(response.getUsage()),
+ finishReasonFrom(response.getChoices().get(0).getFinishReason())
+ );
+ }
+
+ @Override
+ public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
+ return generate(messages, toolSpecification != null ? singletonList(toolSpecification) : null);
+ }
+
+ public static class ZhipuAiChatModelBuilder {
+ public ZhipuAiChatModelBuilder() {
+ }
+ }
+}