(improvement)[build] Use Spotless to customize the code formatting (#1750)

lexluo09
2024-10-04 00:05:04 +08:00
committed by GitHub
parent 44d1cde34f
commit 71a9954be5
521 changed files with 7811 additions and 13046 deletions
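For context, the reformatting in this diff is what a Spotless Java step produces once it is bound to the build. Below is a minimal sketch of the Maven wiring; the plugin version, the Eclipse-formatter step, and the formatter-config path are illustrative assumptions, not taken from this commit:

<plugin>
    <groupId>com.diffplug.spotless</groupId>
    <artifactId>spotless-maven-plugin</artifactId>
    <version>2.43.0</version> <!-- assumed version; pin to whatever the project uses -->
    <configuration>
        <java>
            <!-- Format all Java sources with an Eclipse formatter profile
                 (hypothetical path; point it at the project's own config file) -->
            <eclipse>
                <file>${project.basedir}/tools/eclipse-formatter.xml</file>
            </eclipse>
        </java>
    </configuration>
    <executions>
        <execution>
            <!-- Fail the build on unformatted code early in the lifecycle -->
            <phase>validate</phase>
            <goals>
                <goal>check</goal>
            </goals>
        </execution>
    </executions>
</plugin>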


@@ -60,8 +60,8 @@ import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
- * gpt-4. You can find description of parameters <a
- * href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
+ * gpt-4. You can find description of parameters
+ * <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -88,32 +88,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final List<ChatModelListener> listeners;
@Builder
-    public OpenAiChatModel(
-            String baseUrl,
-            String apiKey,
-            String organizationId,
-            String modelName,
-            Double temperature,
-            Double topP,
-            List<String> stop,
-            Integer maxTokens,
-            Double presencePenalty,
-            Double frequencyPenalty,
-            Map<String, Integer> logitBias,
-            String responseFormat,
-            Boolean strictJsonSchema,
-            Integer seed,
-            String user,
-            Boolean strictTools,
-            Boolean parallelToolCalls,
-            Duration timeout,
-            Integer maxRetries,
-            Proxy proxy,
-            Boolean logRequests,
-            Boolean logResponses,
-            Tokenizer tokenizer,
-            Map<String, String> customHeaders,
-            List<ChatModelListener> listeners) {
+    public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
+            Double temperature, Double topP, List<String> stop, Integer maxTokens,
+            Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
+            String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
+            Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries,
+            Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer,
+            Map<String, String> customHeaders, List<ChatModelListener> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -123,21 +104,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
timeout = getOrDefault(timeout, ofSeconds(60));
-        this.client =
-                OpenAiClient.builder()
-                        .openAiApiKey(apiKey)
-                        .baseUrl(baseUrl)
-                        .organizationId(organizationId)
-                        .callTimeout(timeout)
-                        .connectTimeout(timeout)
-                        .readTimeout(timeout)
-                        .writeTimeout(timeout)
-                        .proxy(proxy)
-                        .logRequests(logRequests)
-                        .logResponses(logResponses)
-                        .userAgent(DEFAULT_USER_AGENT)
-                        .customHeaders(customHeaders)
-                        .build();
+        this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
+                .organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
+                .readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
+                .logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
+                .customHeaders(customHeaders).build();
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
@@ -146,14 +117,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.logitBias = logitBias;
-        this.responseFormat =
-                responseFormat == null
-                        ? null
-                        : ResponseFormat.builder()
-                                .type(
-                                        ResponseFormatType.valueOf(
-                                                responseFormat.toUpperCase(Locale.ROOT)))
-                                .build();
+        this.responseFormat = responseFormat == null ? null
+                : ResponseFormat.builder()
+                        .type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
+                        .build();
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
this.seed = seed;
this.user = user;
@@ -183,61 +150,44 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
@Override
-    public Response<AiMessage> generate(
-            List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
+    public Response<AiMessage> generate(List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null, this.responseFormat);
}
@Override
-    public Response<AiMessage> generate(
-            List<ChatMessage> messages, ToolSpecification toolSpecification) {
-        return generate(
-                messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
+    public Response<AiMessage> generate(List<ChatMessage> messages,
+            ToolSpecification toolSpecification) {
+        return generate(messages, singletonList(toolSpecification), toolSpecification,
+                this.responseFormat);
}
@Override
public ChatResponse chat(ChatRequest request) {
Response<AiMessage> response =
-                generate(
-                        request.messages(),
-                        request.toolSpecifications(),
-                        null,
+                generate(request.messages(), request.toolSpecifications(), null,
getOrDefault(
toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
this.responseFormat));
-        return ChatResponse.builder()
-                .aiMessage(response.content())
-                .tokenUsage(response.tokenUsage())
-                .finishReason(response.finishReason())
-                .build();
+        return ChatResponse.builder().aiMessage(response.content())
+                .tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build();
}
-    private Response<AiMessage> generate(
-            List<ChatMessage> messages,
-            List<ToolSpecification> toolSpecifications,
-            ToolSpecification toolThatMustBeExecuted,
+    private Response<AiMessage> generate(List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted,
ResponseFormat responseFormat) {
-        if (responseFormat != null
-                && responseFormat.type() == JSON_SCHEMA
+        if (responseFormat != null && responseFormat.type() == JSON_SCHEMA
&& responseFormat.jsonSchema() == null) {
responseFormat = null;
}
-        ChatCompletionRequest.Builder requestBuilder =
-                ChatCompletionRequest.builder()
-                        .model(modelName)
-                        .messages(toOpenAiMessages(messages))
-                        .topP(topP)
-                        .stop(stop)
-                        .maxTokens(maxTokens)
-                        .presencePenalty(presencePenalty)
-                        .frequencyPenalty(frequencyPenalty)
-                        .logitBias(logitBias)
-                        .responseFormat(responseFormat)
-                        .seed(seed)
-                        .user(user)
-                        .parallelToolCalls(parallelToolCalls);
+        ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
+                .model(modelName).messages(toOpenAiMessages(messages)).topP(topP).stop(stop)
+                .maxTokens(maxTokens).presencePenalty(presencePenalty)
+                .frequencyPenalty(frequencyPenalty).logitBias(logitBias)
+                .responseFormat(responseFormat).seed(seed).user(user)
+                .parallelToolCalls(parallelToolCalls);
if (!(baseUrl.contains(ZHIPU))) {
requestBuilder.temperature(temperature);
@@ -257,40 +207,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext =
new ChatModelRequestContext(modelListenerRequest, attributes);
-        listeners.forEach(
-                listener -> {
-                    try {
-                        listener.onRequest(requestContext);
-                    } catch (Exception e) {
-                        log.warn("Exception while calling model listener", e);
-                    }
-                });
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(requestContext);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
try {
ChatCompletionResponse chatCompletionResponse =
withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
-            Response<AiMessage> response =
-                    Response.from(
-                            aiMessageFrom(chatCompletionResponse),
-                            tokenUsageFrom(chatCompletionResponse.usage()),
-                            finishReasonFrom(
-                                    chatCompletionResponse.choices().get(0).finishReason()));
+            Response<AiMessage> response = Response.from(aiMessageFrom(chatCompletionResponse),
+                    tokenUsageFrom(chatCompletionResponse.usage()),
+                    finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()));
-            ChatModelResponse modelListenerResponse =
-                    createModelListenerResponse(
-                            chatCompletionResponse.id(), chatCompletionResponse.model(), response);
-            ChatModelResponseContext responseContext =
-                    new ChatModelResponseContext(
-                            modelListenerResponse, modelListenerRequest, attributes);
-            listeners.forEach(
-                    listener -> {
-                        try {
-                            listener.onResponse(responseContext);
-                        } catch (Exception e) {
-                            log.warn("Exception while calling model listener", e);
-                        }
-                    });
+            ChatModelResponse modelListenerResponse = createModelListenerResponse(
+                    chatCompletionResponse.id(), chatCompletionResponse.model(), response);
+            ChatModelResponseContext responseContext = new ChatModelResponseContext(
+                    modelListenerResponse, modelListenerRequest, attributes);
+            listeners.forEach(listener -> {
+                try {
+                    listener.onResponse(responseContext);
+                } catch (Exception e) {
+                    log.warn("Exception while calling model listener", e);
+                }
+            });
return response;
} catch (RuntimeException e) {
@@ -305,14 +248,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
ChatModelErrorContext errorContext =
new ChatModelErrorContext(error, modelListenerRequest, null, attributes);
-            listeners.forEach(
-                    listener -> {
-                        try {
-                            listener.onError(errorContext);
-                        } catch (Exception e2) {
-                            log.warn("Exception while calling model listener", e2);
-                        }
-                    });
+            listeners.forEach(listener -> {
+                try {
+                    listener.onError(errorContext);
+                } catch (Exception e2) {
+                    log.warn("Exception while calling model listener", e2);
+                }
+            });
throw e;
}
@@ -328,8 +270,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
public static OpenAiChatModelBuilder builder() {
-        for (OpenAiChatModelBuilderFactory factory :
-                loadFactories(OpenAiChatModelBuilderFactory.class)) {
+        for (OpenAiChatModelBuilderFactory factory : loadFactories(
+                OpenAiChatModelBuilderFactory.class)) {
return factory.get();
}
return new OpenAiChatModelBuilder();
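
With a setup along the lines of the sketch above, contributors run mvn spotless:apply to rewrite sources in place (which is what produced the bulk edits in this commit) and mvn spotless:check to verify formatting without modifying files; both goals are provided by spotless-maven-plugin, and binding check to the validate phase makes CI reject unformatted code.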