(improvement)(auth) Optimize the code to support configurable token timeout duration, with a default value set to 2 hours. (#1077)

This commit is contained in:
lexluo09
2024-06-02 00:08:24 +08:00
committed by GitHub
parent 2da0eb126a
commit 78d8e652cd
6 changed files with 73 additions and 233 deletions

View File

@@ -22,6 +22,7 @@ import java.util.Collections;
import java.util.List;
public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final OpenAiClient client;
private final String modelName;
private final Double temperature;
@@ -34,36 +35,36 @@ public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimat
private final Tokenizer tokenizer;
// Builds a chat model backed by the OpenAI-compatible HTTP client, applying
// sensible defaults for every optional argument (base URL, timeout, model
// name, temperature, retry count, tokenizer).
//
// NOTE(review): this span is a diff hunk whose +/- markers were stripped.
// Each changed statement appears twice below — the old cast-based form
// followed by its cast-free replacement. Only the SECOND of each pair exists
// in the committed file; the change itself just drops casts made redundant
// by a generic Utils.getOrDefault.
public FullOpenAiChatModel(String baseUrl, String apiKey, String modelName, Double temperature,
Double topP, List<String> stop, Integer maxTokens, Double presencePenalty,
Double frequencyPenalty, Duration timeout, Integer maxRetries, Proxy proxy,
Boolean logRequests, Boolean logResponses, Tokenizer tokenizer) {
baseUrl = (String) Utils.getOrDefault(baseUrl, "https://api.openai.com/v1");
Double topP, List<String> stop, Integer maxTokens, Double presencePenalty,
Double frequencyPenalty, Duration timeout, Integer maxRetries, Proxy proxy,
Boolean logRequests, Boolean logResponses, Tokenizer tokenizer) {
baseUrl = Utils.getOrDefault(baseUrl, "https://api.openai.com/v1");
// The magic "demo" key redirects traffic to the public langchain4j demo proxy.
if ("demo".equals(apiKey)) {
baseUrl = "http://langchain4j.dev/demo/openai/v1";
}
timeout = (Duration) Utils.getOrDefault(timeout, Duration.ofSeconds(60L));
// Default network timeout: 60 seconds, applied uniformly to call/connect/read/write.
timeout = Utils.getOrDefault(timeout, Duration.ofSeconds(60L));
this.client = OpenAiClient.builder().openAiApiKey(apiKey)
.baseUrl(baseUrl).callTimeout(timeout).connectTimeout(timeout)
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
.logRequests(logRequests).logResponses(logResponses).build();
this.modelName = (String) Utils.getOrDefault(modelName, "gpt-3.5-turbo");
this.temperature = (Double) Utils.getOrDefault(temperature, 0.7D);
.baseUrl(baseUrl).callTimeout(timeout).connectTimeout(timeout)
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
.logRequests(logRequests).logResponses(logResponses).build();
this.modelName = Utils.getOrDefault(modelName, "gpt-3.5-turbo");
this.temperature = Utils.getOrDefault(temperature, 0.7D);
// These sampling/penalty knobs have no defaults — null means "omit from request".
this.topP = topP;
this.stop = stop;
this.maxTokens = maxTokens;
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.maxRetries = (Integer) Utils.getOrDefault(maxRetries, 3);
this.tokenizer = (Tokenizer) Utils.getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
this.maxRetries = Utils.getOrDefault(maxRetries, 3);
// Tokenizer defaults to one matching the resolved model name.
this.tokenizer = Utils.getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
}
// Generates a reply with no tools available; delegates to the private overload.
// NOTE(review): diff artifact — the old explicitly-cast call and its
// replacement both appear; only the second return line exists in the real file.
public Response<AiMessage> generate(List<ChatMessage> messages) {
return this.generate(messages, (List) null, (ToolSpecification) null);
return this.generate(messages, null, null);
}
// Generates a reply with the given tools offered but none forced; delegates
// to the private overload.
// NOTE(review): diff artifact — the old explicitly-cast call and its
// replacement both appear; only the second return line exists in the real file.
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return this.generate(messages, toolSpecifications, (ToolSpecification) null);
return this.generate(messages, toolSpecifications, null);
}
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
@@ -71,8 +72,8 @@ public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimat
}
private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
Builder requestBuilder = null;
if (modelName.contains(ChatModel.ZHIPU.toString()) || modelName.contains(ChatModel.ALI.toString())) {
requestBuilder = ChatCompletionRequest.builder()
@@ -107,15 +108,12 @@ public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimat
return this.tokenizer.estimateTokenCountInMessages(messages);
}
/**
 * Convenience factory: builds a model with every option defaulted except the
 * API key.
 *
 * @param apiKey the OpenAI API key (the literal "demo" routes to the demo proxy)
 * @return a fully configured chat model
 */
public static FullOpenAiChatModel withApiKey(String apiKey) {
    FullOpenAiChatModelBuilder preset = builder().apiKey(apiKey);
    return preset.build();
}
/**
 * Creates a fresh builder for configuring a {@code FullOpenAiChatModel}.
 *
 * @return a new, empty builder
 */
public static FullOpenAiChatModel.FullOpenAiChatModelBuilder builder() {
    // The enclosing-class qualifier is redundant inside the class itself.
    return new FullOpenAiChatModelBuilder();
}
public static class FullOpenAiChatModelBuilder {
private String baseUrl;
private String apiKey;
private String modelName;

View File

@@ -1,32 +1,18 @@
package dev.langchain4j.model.openai;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Function;
import dev.ai4j.openai4j.chat.FunctionCall;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.Parameters;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.ChatModel;
import dev.langchain4j.model.output.TokenUsage;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
public class ImproveInternalOpenAiHelper {
static final String OPENAI_URL = "https://api.openai.com/v1";
static final String OPENAI_DEMO_API_KEY = "demo";
static final String OPENAI_DEMO_URL = "http://langchain4j.dev/demo/openai/v1";
// NOTE(review): every member of this helper is static, so this looks like a
// utility class; the idiomatic fix is a private constructor to prevent
// instantiation, but narrowing visibility here could break existing callers,
// so it is only flagged, not changed.
public ImproveInternalOpenAiHelper() {
}
@@ -77,35 +63,4 @@ public class ImproveInternalOpenAiHelper {
}
}
/**
 * Converts the given tool specifications into OpenAI {@code Function}
 * descriptors, one per specification, preserving iteration order.
 *
 * @param toolSpecifications the tools to expose to the model
 * @return the corresponding function descriptors
 */
public static List<Function> toFunctions(Collection<ToolSpecification> toolSpecifications) {
    // Collectors.toList() already produces List<Function>; the original
    // raw (List) cast was redundant and defeated generic type checking.
    return toolSpecifications.stream()
            .map(ImproveInternalOpenAiHelper::toFunction)
            .collect(Collectors.toList());
}
/**
 * Maps a single tool specification onto the OpenAI function-call schema.
 *
 * @param toolSpecification the tool to describe
 * @return the equivalent OpenAI function descriptor
 */
private static Function toFunction(ToolSpecification toolSpecification) {
    Parameters parameters = toOpenAiParameters(toolSpecification.parameters());
    return Function.builder()
            .name(toolSpecification.name())
            .description(toolSpecification.description())
            .parameters(parameters)
            .build();
}
/**
 * Converts tool parameter metadata into the OpenAI parameter schema.
 *
 * @param toolParameters the declared parameters, or {@code null} when the tool
 *                       takes none
 * @return a populated schema, or an empty one when no parameters are declared
 */
private static Parameters toOpenAiParameters(ToolParameters toolParameters) {
    // A tool without declared parameters still needs an (empty) schema object.
    if (toolParameters == null) {
        return Parameters.builder().build();
    }
    return Parameters.builder()
            .properties(toolParameters.properties())
            .required(toolParameters.required())
            .build();
}
/**
 * Extracts the assistant's reply from a chat completion response.
 *
 * <p>A non-null {@code content} is treated as a plain text reply; otherwise
 * the response is assumed to carry a function call on its first choice.
 *
 * @param response the raw completion response
 * @return either a text message or a tool-execution-request message
 */
public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
    if (response.content() != null) {
        return AiMessage.aiMessage(response.content());
    }
    // No text content: the model chose to call a function instead.
    ChatCompletionChoice firstChoice = (ChatCompletionChoice) response.choices().get(0);
    FunctionCall call = firstChoice.message().functionCall();
    ToolExecutionRequest request = ToolExecutionRequest.builder()
            .name(call.name())
            .arguments(call.arguments())
            .build();
    return AiMessage.aiMessage(request);
}
/**
 * Converts the provider's usage block into the langchain4j token-usage type.
 *
 * @param openAiUsage the usage section of a response, or {@code null} when the
 *                    provider omitted it
 * @return the converted usage, or {@code null} when none was reported
 */
public static TokenUsage tokenUsageFrom(Usage openAiUsage) {
    if (openAiUsage == null) {
        return null;
    }
    return new TokenUsage(
            openAiUsage.promptTokens(),
            openAiUsage.completionTokens(),
            openAiUsage.totalTokens());
}
}