mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(auth) Optimize the code to support configurable token timeout duration, with a default value set to 2 hours. (#1077)
This commit is contained in:
@@ -22,6 +22,7 @@ import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
|
||||
private final OpenAiClient client;
|
||||
private final String modelName;
|
||||
private final Double temperature;
|
||||
@@ -34,36 +35,36 @@ public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimat
|
||||
private final Tokenizer tokenizer;
|
||||
|
||||
public FullOpenAiChatModel(String baseUrl, String apiKey, String modelName, Double temperature,
|
||||
Double topP, List<String> stop, Integer maxTokens, Double presencePenalty,
|
||||
Double frequencyPenalty, Duration timeout, Integer maxRetries, Proxy proxy,
|
||||
Boolean logRequests, Boolean logResponses, Tokenizer tokenizer) {
|
||||
baseUrl = (String) Utils.getOrDefault(baseUrl, "https://api.openai.com/v1");
|
||||
Double topP, List<String> stop, Integer maxTokens, Double presencePenalty,
|
||||
Double frequencyPenalty, Duration timeout, Integer maxRetries, Proxy proxy,
|
||||
Boolean logRequests, Boolean logResponses, Tokenizer tokenizer) {
|
||||
baseUrl = Utils.getOrDefault(baseUrl, "https://api.openai.com/v1");
|
||||
if ("demo".equals(apiKey)) {
|
||||
baseUrl = "http://langchain4j.dev/demo/openai/v1";
|
||||
}
|
||||
|
||||
timeout = (Duration) Utils.getOrDefault(timeout, Duration.ofSeconds(60L));
|
||||
timeout = Utils.getOrDefault(timeout, Duration.ofSeconds(60L));
|
||||
this.client = OpenAiClient.builder().openAiApiKey(apiKey)
|
||||
.baseUrl(baseUrl).callTimeout(timeout).connectTimeout(timeout)
|
||||
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
||||
.logRequests(logRequests).logResponses(logResponses).build();
|
||||
this.modelName = (String) Utils.getOrDefault(modelName, "gpt-3.5-turbo");
|
||||
this.temperature = (Double) Utils.getOrDefault(temperature, 0.7D);
|
||||
.baseUrl(baseUrl).callTimeout(timeout).connectTimeout(timeout)
|
||||
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
||||
.logRequests(logRequests).logResponses(logResponses).build();
|
||||
this.modelName = Utils.getOrDefault(modelName, "gpt-3.5-turbo");
|
||||
this.temperature = Utils.getOrDefault(temperature, 0.7D);
|
||||
this.topP = topP;
|
||||
this.stop = stop;
|
||||
this.maxTokens = maxTokens;
|
||||
this.presencePenalty = presencePenalty;
|
||||
this.frequencyPenalty = frequencyPenalty;
|
||||
this.maxRetries = (Integer) Utils.getOrDefault(maxRetries, 3);
|
||||
this.tokenizer = (Tokenizer) Utils.getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
|
||||
this.maxRetries = Utils.getOrDefault(maxRetries, 3);
|
||||
this.tokenizer = Utils.getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
|
||||
}
|
||||
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
||||
return this.generate(messages, (List) null, (ToolSpecification) null);
|
||||
return this.generate(messages, null, null);
|
||||
}
|
||||
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
|
||||
return this.generate(messages, toolSpecifications, (ToolSpecification) null);
|
||||
return this.generate(messages, toolSpecifications, null);
|
||||
}
|
||||
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
|
||||
@@ -71,8 +72,8 @@ public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimat
|
||||
}
|
||||
|
||||
private Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications,
|
||||
ToolSpecification toolThatMustBeExecuted) {
|
||||
List<ToolSpecification> toolSpecifications,
|
||||
ToolSpecification toolThatMustBeExecuted) {
|
||||
Builder requestBuilder = null;
|
||||
if (modelName.contains(ChatModel.ZHIPU.toString()) || modelName.contains(ChatModel.ALI.toString())) {
|
||||
requestBuilder = ChatCompletionRequest.builder()
|
||||
@@ -107,15 +108,12 @@ public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimat
|
||||
return this.tokenizer.estimateTokenCountInMessages(messages);
|
||||
}
|
||||
|
||||
public static FullOpenAiChatModel withApiKey(String apiKey) {
|
||||
return builder().apiKey(apiKey).build();
|
||||
}
|
||||
|
||||
public static FullOpenAiChatModel.FullOpenAiChatModelBuilder builder() {
|
||||
return new FullOpenAiChatModel.FullOpenAiChatModelBuilder();
|
||||
}
|
||||
|
||||
public static class FullOpenAiChatModelBuilder {
|
||||
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String modelName;
|
||||
|
||||
@@ -1,32 +1,18 @@
|
||||
package dev.langchain4j.model.openai;
|
||||
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
|
||||
import dev.ai4j.openai4j.chat.Function;
|
||||
import dev.ai4j.openai4j.chat.FunctionCall;
|
||||
import dev.ai4j.openai4j.chat.Message;
|
||||
import dev.ai4j.openai4j.chat.Parameters;
|
||||
import dev.ai4j.openai4j.chat.Role;
|
||||
import dev.ai4j.openai4j.shared.Usage;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolParameters;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.ChatModel;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ImproveInternalOpenAiHelper {
|
||||
static final String OPENAI_URL = "https://api.openai.com/v1";
|
||||
static final String OPENAI_DEMO_API_KEY = "demo";
|
||||
static final String OPENAI_DEMO_URL = "http://langchain4j.dev/demo/openai/v1";
|
||||
|
||||
public ImproveInternalOpenAiHelper() {
|
||||
}
|
||||
@@ -77,35 +63,4 @@ public class ImproveInternalOpenAiHelper {
|
||||
}
|
||||
}
|
||||
|
||||
public static List<Function> toFunctions(Collection<ToolSpecification> toolSpecifications) {
|
||||
return (List) toolSpecifications.stream().map(ImproveInternalOpenAiHelper::toFunction)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Function toFunction(ToolSpecification toolSpecification) {
|
||||
return Function.builder().name(toolSpecification.name())
|
||||
.description(toolSpecification.description())
|
||||
.parameters(toOpenAiParameters(toolSpecification.parameters())).build();
|
||||
}
|
||||
|
||||
private static Parameters toOpenAiParameters(ToolParameters toolParameters) {
|
||||
return toolParameters == null ? Parameters.builder().build() : Parameters.builder()
|
||||
.properties(toolParameters.properties()).required(toolParameters.required()).build();
|
||||
}
|
||||
|
||||
public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
|
||||
if (response.content() != null) {
|
||||
return AiMessage.aiMessage(response.content());
|
||||
} else {
|
||||
FunctionCall functionCall = ((ChatCompletionChoice) response.choices().get(0)).message().functionCall();
|
||||
ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
|
||||
.name(functionCall.name()).arguments(functionCall.arguments()).build();
|
||||
return AiMessage.aiMessage(toolExecutionRequest);
|
||||
}
|
||||
}
|
||||
|
||||
public static TokenUsage tokenUsageFrom(Usage openAiUsage) {
|
||||
return openAiUsage == null ? null : new TokenUsage(openAiUsage.promptTokens(),
|
||||
openAiUsage.completionTokens(), openAiUsage.totalTokens());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user