Files
supersonic/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java
2025-02-12 20:28:10 +08:00

300 lines
13 KiB
Java

package dev.langchain4j.model.openai;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.ResponseFormat;
import dev.ai4j.openai4j.chat.ResponseFormatType;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import java.net.Proxy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import static dev.ai4j.openai4j.chat.ResponseFormatType.JSON_SCHEMA;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
* gpt-4. A description of the parameters can be found
* <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
/**
* Marker searched for in {@code baseUrl}; when the base URL contains it, requests are built
* without a temperature value.
*/
public static final String ZHIPU = "bigmodel";
// HTTP client built in the constructor; executes the chat completion calls.
private final OpenAiClient client;
// Resolved base URL: the caller's value, OPENAI_URL when none was given, or OPENAI_DEMO_URL
// when the demo API key is used.
private final String baseUrl;
// Model name sent with every request (GPT_3_5_TURBO when none was given).
private final String modelName;
// API version as passed to the constructor (also handed to the client builder there).
private final String apiVersion;
// Generation parameters copied into each ChatCompletionRequest.
private final Double temperature;
private final Double topP;
private final List<String> stop;
private final Integer maxTokens;
private final Double presencePenalty;
private final Double frequencyPenalty;
private final Map<String, Integer> logitBias;
// Model-level response format; used unless a ChatRequest supplies its own.
private final ResponseFormat responseFormat;
// Passed to toOpenAiResponseFormat when converting a ChatRequest's response format.
private final Boolean strictJsonSchema;
private final Integer seed;
private final String user;
// Passed to toTools when tool specifications are attached to a request.
private final Boolean strictTools;
private final Boolean parallelToolCalls;
// Retry setting handed to withRetry around the chat completion call.
private final Integer maxRetries;
// Used by estimateTokenCount.
private final Tokenizer tokenizer;
// Listeners notified on request, response and error; never null (empty list when none given).
private final List<ChatModelListener> listeners;
@Builder
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName, String apiVersion,
Double temperature, Double topP, List<String> stop, Integer maxTokens,
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries,
Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer,
Map<String, String> customHeaders, List<ChatModelListener> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
baseUrl = OPENAI_DEMO_URL;
}
this.baseUrl = baseUrl;
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl).apiVersion(apiVersion)
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders).build();
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.apiVersion = apiVersion;
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.stop = stop;
this.maxTokens = maxTokens;
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.logitBias = logitBias;
this.responseFormat = responseFormat == null ? null
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
this.seed = seed;
this.user = user;
this.strictTools = getOrDefault(strictTools, false);
this.parallelToolCalls = parallelToolCalls;
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
}
/**
 * Returns the model name this instance sends with its requests.
 *
 * @return the configured model name
 */
public String modelName() {
    return this.modelName;
}
/**
 * Reports the capabilities of this model instance.
 *
 * @return a new mutable set that contains {@code RESPONSE_FORMAT_JSON_SCHEMA} only when the
 *         model-level response format is of type {@code JSON_SCHEMA}; otherwise an empty set
 */
@Override
public Set<Capability> supportedCapabilities() {
    final Set<Capability> supported = new HashSet<>();
    final boolean jsonSchemaConfigured =
            responseFormat != null && responseFormat.type() == JSON_SCHEMA;
    if (jsonSchemaConfigured) {
        supported.add(RESPONSE_FORMAT_JSON_SCHEMA);
    }
    return supported;
}
/**
 * Generates a reply for the given messages with no tools attached, using the model-level
 * response format.
 *
 * @param messages the conversation to send
 * @return the model's reply
 */
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
    final List<ToolSpecification> noTools = null;
    final ToolSpecification noForcedTool = null;
    return generate(messages, noTools, noForcedTool, this.responseFormat);
}
/**
 * Generates a reply for the given messages, offering the given tools to the model without
 * forcing any of them, and using the model-level response format.
 *
 * @param messages the conversation to send
 * @param toolSpecifications tools offered to the model
 * @return the model's reply
 */
@Override
public Response<AiMessage> generate(List<ChatMessage> messages,
        List<ToolSpecification> toolSpecifications) {
    final ToolSpecification noForcedTool = null;
    return generate(messages, toolSpecifications, noForcedTool, this.responseFormat);
}
/**
 * Generates a reply for the given messages with a single tool that is both offered to the
 * model and set as the tool choice, using the model-level response format.
 *
 * @param messages the conversation to send
 * @param toolSpecification the only tool offered; its name is used as the tool choice
 * @return the model's reply
 */
@Override
public Response<AiMessage> generate(List<ChatMessage> messages,
        ToolSpecification toolSpecification) {
    final List<ToolSpecification> onlyTool = singletonList(toolSpecification);
    return generate(messages, onlyTool, toolSpecification, this.responseFormat);
}
/**
 * Executes a {@link ChatRequest} and wraps the outcome in a {@link ChatResponse}.
 *
 * <p>The request's own response format (converted with {@code toOpenAiResponseFormat}) is
 * used when that conversion yields a non-null value; otherwise the model-level format
 * applies. No tool is forced.
 *
 * @param request the messages, tool specifications and optional response format
 * @return the AI message, token usage and finish reason of the reply
 */
@Override
public ChatResponse chat(ChatRequest request) {
    final List<ChatMessage> messages = request.messages();
    final List<ToolSpecification> tools = request.toolSpecifications();
    final ResponseFormat effectiveFormat = getOrDefault(
            toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
            this.responseFormat);

    final Response<AiMessage> reply = generate(messages, tools, null, effectiveFormat);

    return ChatResponse.builder()
            .aiMessage(reply.content())
            .tokenUsage(reply.tokenUsage())
            .finishReason(reply.finishReason())
            .build();
}
/**
 * Builds the chat completion request, notifies the listeners, executes the call with retries
 * and converts the result.
 *
 * <p>Listener failures are logged at WARN level and never interrupt the call. When the call
 * itself fails with a {@link RuntimeException}, listeners receive the cause if it is an
 * {@link OpenAiHttpException} (otherwise the exception itself), and the original exception
 * is rethrown.
 *
 * @param messages the conversation to send
 * @param toolSpecifications tools to attach; ignored when {@code null} or empty
 * @param toolThatMustBeExecuted when non-null, its name is set as the tool choice
 * @param responseFormat response format to send; a {@code JSON_SCHEMA} format that carries
 *        no schema is replaced by {@code null}
 * @return the AI message, token usage and finish reason taken from the first choice
 */
private Response<AiMessage> generate(List<ChatMessage> messages,
        List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted,
        ResponseFormat responseFormat) {
    final boolean schemaMissing = responseFormat != null
            && responseFormat.type() == JSON_SCHEMA
            && responseFormat.jsonSchema() == null;
    final ResponseFormat effectiveFormat = schemaMissing ? null : responseFormat;

    ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
            .model(modelName)
            .messages(toOpenAiMessages(messages))
            .topP(topP)
            .stop(stop)
            .maxTokens(maxTokens)
            .presencePenalty(presencePenalty)
            .frequencyPenalty(frequencyPenalty)
            .logitBias(logitBias)
            .responseFormat(effectiveFormat)
            .seed(seed)
            .user(user)
            .parallelToolCalls(parallelToolCalls);

    // Endpoints whose base URL contains the ZHIPU marker get no temperature in the request.
    final boolean zhipuEndpoint = baseUrl.contains(ZHIPU);
    if (!zhipuEndpoint) {
        requestBuilder.temperature(temperature);
    }

    final boolean hasTools = toolSpecifications != null && !toolSpecifications.isEmpty();
    if (hasTools) {
        requestBuilder.tools(toTools(toolSpecifications, strictTools));
    }
    if (toolThatMustBeExecuted != null) {
        requestBuilder.toolChoice(toolThatMustBeExecuted.name());
    }

    final ChatCompletionRequest request = requestBuilder.build();
    final ChatModelRequest modelListenerRequest =
            createModelListenerRequest(request, messages, toolSpecifications);
    // The same attribute map is handed to the request, response and error contexts.
    final Map<Object, Object> attributes = new ConcurrentHashMap<>();
    final ChatModelRequestContext requestContext =
            new ChatModelRequestContext(modelListenerRequest, attributes);

    for (ChatModelListener listener : listeners) {
        try {
            listener.onRequest(requestContext);
        } catch (Exception listenerFailure) {
            log.warn("Exception while calling model listener", listenerFailure);
        }
    }

    try {
        final ChatCompletionResponse completion =
                withRetry(() -> client.chatCompletion(request).execute(), maxRetries);

        final Response<AiMessage> result = Response.from(
                aiMessageFrom(completion),
                tokenUsageFrom(completion.usage()),
                finishReasonFrom(completion.choices().get(0).finishReason()));

        final ChatModelResponse modelListenerResponse =
                createModelListenerResponse(completion.id(), completion.model(), result);
        final ChatModelResponseContext responseContext = new ChatModelResponseContext(
                modelListenerResponse, modelListenerRequest, attributes);

        for (ChatModelListener listener : listeners) {
            try {
                listener.onResponse(responseContext);
            } catch (Exception listenerFailure) {
                log.warn("Exception while calling model listener", listenerFailure);
            }
        }

        return result;
    } catch (RuntimeException failure) {
        // Report the HTTP-level cause to listeners when there is one; rethrow the original.
        final Throwable reported = failure.getCause() instanceof OpenAiHttpException
                ? failure.getCause()
                : failure;
        final ChatModelErrorContext errorContext =
                new ChatModelErrorContext(reported, modelListenerRequest, null, attributes);

        for (ChatModelListener listener : listeners) {
            try {
                listener.onError(errorContext);
            } catch (Exception listenerFailure) {
                log.warn("Exception while calling model listener", listenerFailure);
            }
        }

        throw failure;
    }
}
/**
 * Estimates the token count of the given messages by delegating to the configured tokenizer.
 *
 * @param messages the messages to measure
 * @return the tokenizer's estimate
 */
@Override
public int estimateTokenCount(List<ChatMessage> messages) {
    final int estimate = tokenizer.estimateTokenCountInMessages(messages);
    return estimate;
}
/**
 * Convenience factory: builds a model from {@link #builder()} with only the API key set.
 *
 * @param apiKey the API key to configure
 * @return the built model
 */
public static OpenAiChatModel withApiKey(String apiKey) {
    final OpenAiChatModelBuilder configured = builder().apiKey(apiKey);
    return configured.build();
}
/**
 * Returns a builder for this model.
 *
 * <p>If {@code loadFactories} yields at least one {@link OpenAiChatModelBuilderFactory}, the
 * builder supplied by the first one is returned; otherwise a plain
 * {@link OpenAiChatModelBuilder} is created.
 *
 * @return the builder to configure
 */
public static OpenAiChatModelBuilder builder() {
    final java.util.Iterator<OpenAiChatModelBuilderFactory> factories =
            loadFactories(OpenAiChatModelBuilderFactory.class).iterator();
    if (factories.hasNext()) {
        // Only the first factory is consulted; any others are ignored.
        return factories.next().get();
    }
    return new OpenAiChatModelBuilder();
}
/**
 * Builder for {@link OpenAiChatModel}. Only the members written out here are hand-coded; the
 * fields and the remaining methods are presumably generated by Lombok's {@code @Builder} on
 * the constructor.
 */
public static class OpenAiChatModelBuilder {

    /**
     * Declared explicitly as public so that the builder can be extended; left to Lombok, the
     * constructor would be package-private.
     */
    public OpenAiChatModelBuilder() {
    }

    /**
     * Sets the model name.
     *
     * @param modelName the model name as a string
     * @return this builder
     */
    public OpenAiChatModelBuilder modelName(String modelName) {
        this.modelName = modelName;
        return this;
    }

    /**
     * Sets the model name from an {@link OpenAiChatModelName}, using its {@code toString()}.
     *
     * @param modelName the model name constant; must not be {@code null}
     * @return this builder
     */
    public OpenAiChatModelBuilder modelName(OpenAiChatModelName modelName) {
        final String name = modelName.toString();
        this.modelName = name;
        return this;
    }
}
}