package dev.langchain4j.model.openai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.ResponseFormat;
import dev.ai4j.openai4j.chat.ResponseFormatType;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;

import java.net.Proxy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static dev.ai4j.openai4j.chat.ResponseFormatType.JSON_SCHEMA;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerRequest;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;

/**
 * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
 * gpt-4. You can find a description of the parameters
 * <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
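 *
 * <p>A minimal usage sketch (the key lookup and parameter values are illustrative, not prescribed):
 * <pre>{@code
 * OpenAiChatModel model = OpenAiChatModel.builder()
 *         .apiKey(System.getenv("OPENAI_API_KEY")) // assumes the key is exported in the environment
 *         .modelName("gpt-3.5-turbo")
 *         .temperature(0.7)
 *         .build();
 * String answer = model.generate("Hello!");
 * }</pre>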
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {

    private final OpenAiClient client;
    private final String baseUrl;
    private final String modelName;
    private final String apiVersion;
    private final Double temperature;
    private final Double topP;
    private final List<String> stop;
    private final Integer maxTokens;
    private final Double presencePenalty;
    private final Double frequencyPenalty;
    private final Map<String, Integer> logitBias;
    private final ResponseFormat responseFormat;
    private final Boolean strictJsonSchema;
    private final Integer seed;
    private final String user;
    private final Boolean strictTools;
    private final Boolean parallelToolCalls;
    private final Integer maxRetries;
    private final Tokenizer tokenizer;
    private final List<ChatModelListener> listeners;

    @Builder
    public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
                           String apiVersion, Double temperature, Double topP, List<String> stop,
                           Integer maxTokens, Double presencePenalty, Double frequencyPenalty,
                           Map<String, Integer> logitBias, String responseFormat, Boolean strictJsonSchema,
                           Integer seed, String user, Boolean strictTools, Boolean parallelToolCalls,
                           Duration timeout, Integer maxRetries, Proxy proxy, Boolean logRequests,
                           Boolean logResponses, Tokenizer tokenizer, Map<String, String> customHeaders,
                           List<ChatModelListener> listeners) {

        baseUrl = getOrDefault(baseUrl, OPENAI_URL);
        if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
            // requests made with the demo API key are routed to the demo endpoint
            baseUrl = OPENAI_DEMO_URL;
        }
        this.baseUrl = baseUrl;

        timeout = getOrDefault(timeout, ofSeconds(60));

        this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
                .apiVersion(apiVersion).organizationId(organizationId).callTimeout(timeout)
                .connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
                .logRequests(logRequests).logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
                .customHeaders(customHeaders).build();
        // use toString(), not name(): the enum's toString() yields the API model id ("gpt-3.5-turbo"),
        // consistent with the modelName(OpenAiChatModelName) builder overload below
        this.modelName = getOrDefault(modelName, GPT_3_5_TURBO.toString());
        this.apiVersion = apiVersion;
        this.temperature = getOrDefault(temperature, 0.7);
        this.topP = topP;
        this.stop = stop;
        this.maxTokens = maxTokens;
        this.presencePenalty = presencePenalty;
        this.frequencyPenalty = frequencyPenalty;
        this.logitBias = logitBias;
        this.responseFormat = responseFormat == null ? null
                : ResponseFormat.builder()
                        .type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
                        .build();
        this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
        this.seed = seed;
        this.user = user;
        this.strictTools = getOrDefault(strictTools, false);
        this.parallelToolCalls = parallelToolCalls;
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.tokenizer = getOrDefault(tokenizer, () -> new OpenAiTokenizer(this.modelName));
        this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
    }

    public String modelName() {
        return modelName;
    }

    @Override
    public Set<Capability> supportedCapabilities() {
        Set<Capability> capabilities = new HashSet<>();
        if (responseFormat != null && responseFormat.type() == JSON_SCHEMA) {
            capabilities.add(RESPONSE_FORMAT_JSON_SCHEMA);
        }
        return capabilities;
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        return generate(messages, null, null, this.responseFormat);
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages,
                                        List<ToolSpecification> toolSpecifications) {
        return generate(messages, toolSpecifications, null, this.responseFormat);
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages,
                                        ToolSpecification toolSpecification) {
        return generate(messages, singletonList(toolSpecification), toolSpecification,
                this.responseFormat);
    }

    @Override
    public ChatResponse chat(ChatRequest request) {
        Response<AiMessage> response =
                generate(request.messages(), request.toolSpecifications(), null,
                        getOrDefault(
                                toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
                                this.responseFormat));
        return ChatResponse.builder().aiMessage(response.content())
                .tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build();
    }
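
    // Central request path shared by the public generate/chat variants above.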
    private Response<AiMessage> generate(List<ChatMessage> messages,
                                         List<ToolSpecification> toolSpecifications,
                                         ToolSpecification toolThatMustBeExecuted,
                                         ResponseFormat responseFormat) {

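        // a JSON_SCHEMA response format without an attached schema cannot be honored, so drop it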
        if (responseFormat != null && responseFormat.type() == JSON_SCHEMA
                && responseFormat.jsonSchema() == null) {
            responseFormat = null;
        }

        ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
                .model(modelName).messages(toOpenAiMessages(messages)).temperature(temperature)
                .topP(topP).stop(stop).maxTokens(maxTokens).presencePenalty(presencePenalty)
                .frequencyPenalty(frequencyPenalty).logitBias(logitBias)
                .responseFormat(responseFormat).seed(seed).user(user)
                .parallelToolCalls(parallelToolCalls);

        if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
            requestBuilder.tools(toTools(toolSpecifications, strictTools));
        }
        if (toolThatMustBeExecuted != null) {
            // force the model to call this specific tool
            requestBuilder.toolChoice(toolThatMustBeExecuted.name());
        }

        ChatCompletionRequest request = requestBuilder.build();

        ChatModelRequest modelListenerRequest =
                createModelListenerRequest(request, messages, toolSpecifications);
        Map<Object, Object> attributes = new ConcurrentHashMap<>();
        ChatModelRequestContext requestContext =
                new ChatModelRequestContext(modelListenerRequest, attributes);
        listeners.forEach(listener -> {
            try {
                listener.onRequest(requestContext);
            } catch (Exception e) {
                // a misbehaving listener must not break the request itself
                log.warn("Exception while calling model listener", e);
            }
        });

        try {
            ChatCompletionResponse chatCompletionResponse =
                    withRetry(() -> client.chatCompletion(request).execute(), maxRetries);

            Response<AiMessage> response = Response.from(aiMessageFrom(chatCompletionResponse),
                    tokenUsageFrom(chatCompletionResponse.usage()),
                    finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()));

            ChatModelResponse modelListenerResponse = createModelListenerResponse(
                    chatCompletionResponse.id(), chatCompletionResponse.model(), response);
            ChatModelResponseContext responseContext = new ChatModelResponseContext(
                    modelListenerResponse, modelListenerRequest, attributes);
            listeners.forEach(listener -> {
                try {
                    listener.onResponse(responseContext);
                } catch (Exception e) {
                    log.warn("Exception while calling model listener", e);
                }
            });

            return response;
        } catch (RuntimeException e) {

            // unwrap the HTTP error so listeners see the root cause when the client wrapped it
            Throwable error;
            if (e.getCause() instanceof OpenAiHttpException) {
                error = e.getCause();
            } else {
                error = e;
            }

            ChatModelErrorContext errorContext =
                    new ChatModelErrorContext(error, modelListenerRequest, null, attributes);

            listeners.forEach(listener -> {
                try {
                    listener.onError(errorContext);
                } catch (Exception e2) {
                    log.warn("Exception while calling model listener", e2);
                }
            });

            throw e;
        }
    }

    @Override
    public int estimateTokenCount(List<ChatMessage> messages) {
        return tokenizer.estimateTokenCountInMessages(messages);
    }

    public static OpenAiChatModel withApiKey(String apiKey) {
        return builder().apiKey(apiKey).build();
    }

    public static OpenAiChatModelBuilder builder() {
        // an SPI-provided builder factory, if present on the classpath, takes precedence
        for (OpenAiChatModelBuilderFactory factory : loadFactories(
                OpenAiChatModelBuilderFactory.class)) {
            return factory.get();
        }
        return new OpenAiChatModelBuilder();
    }

    public static class OpenAiChatModelBuilder {

        public OpenAiChatModelBuilder() {
            // This is public so it can be extended
            // By default with Lombok it becomes package private
        }

        public OpenAiChatModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public OpenAiChatModelBuilder modelName(OpenAiChatModelName modelName) {
            this.modelName = modelName.toString();
            return this;
        }
    }
}