mirror of https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
(improvement)[build] Use Spotless to customize the code formatting (#1750)
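The diff below shows google-java-format-style wrapping (one argument or chained call per line) being replaced by a more compact, Eclipse-style wrapping. The Spotless configuration itself is not part of this diff; the following is only a hypothetical sketch of a spotless-maven-plugin setup that could produce this kind of formatting. The plugin version, the choice of the Eclipse formatter, and the style-file path are all assumptions, not taken from this commit.

<plugin>
    <groupId>com.diffplug.spotless</groupId>
    <artifactId>spotless-maven-plugin</artifactId>
    <version>2.43.0</version> <!-- assumed version, not from this commit -->
    <configuration>
        <java>
            <!-- An Eclipse JDT formatter with a project-local style file is one
                 common way to get the compact continuation wrapping seen in the
                 '+' lines below; this file path is hypothetical. -->
            <eclipse>
                <file>${project.basedir}/spotless/formatter.xml</file>
            </eclipse>
            <removeUnusedImports/>
        </java>
    </configuration>
</plugin>

With a setup like this, mvn spotless:apply rewrites the sources in place and mvn spotless:check fails the build when files drift from the configured style.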
@@ -60,8 +60,8 @@ import static java.util.Collections.singletonList;

 /**
  * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
- * gpt-4. You can find description of parameters <a
- * href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
+ * gpt-4. You can find description of parameters
+ * <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
 @Slf4j
 public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -88,32 +88,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     private final List<ChatModelListener> listeners;

     @Builder
-    public OpenAiChatModel(
-            String baseUrl,
-            String apiKey,
-            String organizationId,
-            String modelName,
-            Double temperature,
-            Double topP,
-            List<String> stop,
-            Integer maxTokens,
-            Double presencePenalty,
-            Double frequencyPenalty,
-            Map<String, Integer> logitBias,
-            String responseFormat,
-            Boolean strictJsonSchema,
-            Integer seed,
-            String user,
-            Boolean strictTools,
-            Boolean parallelToolCalls,
-            Duration timeout,
-            Integer maxRetries,
-            Proxy proxy,
-            Boolean logRequests,
-            Boolean logResponses,
-            Tokenizer tokenizer,
-            Map<String, String> customHeaders,
-            List<ChatModelListener> listeners) {
+    public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
+            Double temperature, Double topP, List<String> stop, Integer maxTokens,
+            Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
+            String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
+            Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries,
+            Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer,
+            Map<String, String> customHeaders, List<ChatModelListener> listeners) {

         baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -123,21 +104,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {

         timeout = getOrDefault(timeout, ofSeconds(60));

-        this.client =
-                OpenAiClient.builder()
-                        .openAiApiKey(apiKey)
-                        .baseUrl(baseUrl)
-                        .organizationId(organizationId)
-                        .callTimeout(timeout)
-                        .connectTimeout(timeout)
-                        .readTimeout(timeout)
-                        .writeTimeout(timeout)
-                        .proxy(proxy)
-                        .logRequests(logRequests)
-                        .logResponses(logResponses)
-                        .userAgent(DEFAULT_USER_AGENT)
-                        .customHeaders(customHeaders)
-                        .build();
+        this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
+                .organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
+                .readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
+                .logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
+                .customHeaders(customHeaders).build();
         this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
         this.temperature = getOrDefault(temperature, 0.7);
         this.topP = topP;
@@ -146,14 +117,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
         this.presencePenalty = presencePenalty;
         this.frequencyPenalty = frequencyPenalty;
         this.logitBias = logitBias;
-        this.responseFormat =
-                responseFormat == null
-                        ? null
-                        : ResponseFormat.builder()
-                                .type(
-                                        ResponseFormatType.valueOf(
-                                                responseFormat.toUpperCase(Locale.ROOT)))
-                                .build();
+        this.responseFormat = responseFormat == null ? null
+                : ResponseFormat.builder()
+                        .type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
+                        .build();
         this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
         this.seed = seed;
         this.user = user;
@@ -183,61 +150,44 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     }

     @Override
-    public Response<AiMessage> generate(
-            List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
+    public Response<AiMessage> generate(List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications) {
         return generate(messages, toolSpecifications, null, this.responseFormat);
     }

     @Override
-    public Response<AiMessage> generate(
-            List<ChatMessage> messages, ToolSpecification toolSpecification) {
-        return generate(
-                messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
+    public Response<AiMessage> generate(List<ChatMessage> messages,
+            ToolSpecification toolSpecification) {
+        return generate(messages, singletonList(toolSpecification), toolSpecification,
+                this.responseFormat);
     }

     @Override
     public ChatResponse chat(ChatRequest request) {
         Response<AiMessage> response =
-                generate(
-                        request.messages(),
-                        request.toolSpecifications(),
-                        null,
+                generate(request.messages(), request.toolSpecifications(), null,
                         getOrDefault(
                                 toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
                                 this.responseFormat));
-        return ChatResponse.builder()
-                .aiMessage(response.content())
-                .tokenUsage(response.tokenUsage())
-                .finishReason(response.finishReason())
-                .build();
+        return ChatResponse.builder().aiMessage(response.content())
+                .tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build();
     }

-    private Response<AiMessage> generate(
-            List<ChatMessage> messages,
-            List<ToolSpecification> toolSpecifications,
-            ToolSpecification toolThatMustBeExecuted,
+    private Response<AiMessage> generate(List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted,
             ResponseFormat responseFormat) {

-        if (responseFormat != null
-                && responseFormat.type() == JSON_SCHEMA
+        if (responseFormat != null && responseFormat.type() == JSON_SCHEMA
                 && responseFormat.jsonSchema() == null) {
             responseFormat = null;
         }

-        ChatCompletionRequest.Builder requestBuilder =
-                ChatCompletionRequest.builder()
-                        .model(modelName)
-                        .messages(toOpenAiMessages(messages))
-                        .topP(topP)
-                        .stop(stop)
-                        .maxTokens(maxTokens)
-                        .presencePenalty(presencePenalty)
-                        .frequencyPenalty(frequencyPenalty)
-                        .logitBias(logitBias)
-                        .responseFormat(responseFormat)
-                        .seed(seed)
-                        .user(user)
-                        .parallelToolCalls(parallelToolCalls);
+        ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
+                .model(modelName).messages(toOpenAiMessages(messages)).topP(topP).stop(stop)
+                .maxTokens(maxTokens).presencePenalty(presencePenalty)
+                .frequencyPenalty(frequencyPenalty).logitBias(logitBias)
+                .responseFormat(responseFormat).seed(seed).user(user)
+                .parallelToolCalls(parallelToolCalls);

         if (!(baseUrl.contains(ZHIPU))) {
             requestBuilder.temperature(temperature);
@@ -257,40 +207,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
         Map<Object, Object> attributes = new ConcurrentHashMap<>();
         ChatModelRequestContext requestContext =
                 new ChatModelRequestContext(modelListenerRequest, attributes);
-        listeners.forEach(
-                listener -> {
-                    try {
-                        listener.onRequest(requestContext);
-                    } catch (Exception e) {
-                        log.warn("Exception while calling model listener", e);
-                    }
-                });
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(requestContext);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });

         try {
             ChatCompletionResponse chatCompletionResponse =
                     withRetry(() -> client.chatCompletion(request).execute(), maxRetries);

-            Response<AiMessage> response =
-                    Response.from(
-                            aiMessageFrom(chatCompletionResponse),
-                            tokenUsageFrom(chatCompletionResponse.usage()),
-                            finishReasonFrom(
-                                    chatCompletionResponse.choices().get(0).finishReason()));
+            Response<AiMessage> response = Response.from(aiMessageFrom(chatCompletionResponse),
+                    tokenUsageFrom(chatCompletionResponse.usage()),
+                    finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()));

-            ChatModelResponse modelListenerResponse =
-                    createModelListenerResponse(
-                            chatCompletionResponse.id(), chatCompletionResponse.model(), response);
-            ChatModelResponseContext responseContext =
-                    new ChatModelResponseContext(
-                            modelListenerResponse, modelListenerRequest, attributes);
-            listeners.forEach(
-                    listener -> {
-                        try {
-                            listener.onResponse(responseContext);
-                        } catch (Exception e) {
-                            log.warn("Exception while calling model listener", e);
-                        }
-                    });
+            ChatModelResponse modelListenerResponse = createModelListenerResponse(
+                    chatCompletionResponse.id(), chatCompletionResponse.model(), response);
+            ChatModelResponseContext responseContext = new ChatModelResponseContext(
+                    modelListenerResponse, modelListenerRequest, attributes);
+            listeners.forEach(listener -> {
+                try {
+                    listener.onResponse(responseContext);
+                } catch (Exception e) {
+                    log.warn("Exception while calling model listener", e);
+                }
+            });

             return response;
         } catch (RuntimeException e) {
@@ -305,14 +248,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
             ChatModelErrorContext errorContext =
                     new ChatModelErrorContext(error, modelListenerRequest, null, attributes);

-            listeners.forEach(
-                    listener -> {
-                        try {
-                            listener.onError(errorContext);
-                        } catch (Exception e2) {
-                            log.warn("Exception while calling model listener", e2);
-                        }
-                    });
+            listeners.forEach(listener -> {
+                try {
+                    listener.onError(errorContext);
+                } catch (Exception e2) {
+                    log.warn("Exception while calling model listener", e2);
+                }
+            });

             throw e;
         }
@@ -328,8 +270,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     }

     public static OpenAiChatModelBuilder builder() {
-        for (OpenAiChatModelBuilderFactory factory :
-                loadFactories(OpenAiChatModelBuilderFactory.class)) {
+        for (OpenAiChatModelBuilderFactory factory : loadFactories(
+                OpenAiChatModelBuilderFactory.class)) {
             return factory.get();
         }
         return new OpenAiChatModelBuilder();