From cb139a54e8e1fa010e36799721a69d6404ab6b0c Mon Sep 17 00:00:00 2001 From: zyclove Date: Wed, 12 Feb 2025 20:28:10 +0800 Subject: [PATCH] feat:add openapi supports ApiVersion (#2050) --- .../supersonic/common/pojo/ChatModelConfig.java | 1 + .../supersonic/common/pojo/ChatModelParameters.java | 11 ++++++++++- .../dev/langchain4j/model/openai/OpenAiChatModel.java | 6 ++++-- .../dev/langchain4j/provider/OpenAiModelFactory.java | 2 ++ pom.xml | 2 +- 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java index 56a2bf4a1..b01b75f38 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java @@ -18,6 +18,7 @@ public class ChatModelConfig implements Serializable { private String baseUrl; private String apiKey; private String modelName; + private String apiVersion; private Double temperature = 0.0d; private Long timeOut = 60L; private String endpoint; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelParameters.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelParameters.java index fcf5a2535..5cb277bc8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelParameters.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelParameters.java @@ -34,6 +34,9 @@ public class ChatModelParameters { public static final Parameter CHAT_MODEL_API_KEY = new Parameter("apiKey", "", "ApiKey", "", "password", MODULE_NAME, null, getApiKeyDependency()); + public static final Parameter CHAT_MODEL_API_VERSION = new Parameter("apiVersion", "2024-02-01", "ApiVersion", "", + "string", MODULE_NAME, null, getApiVersionDependency()); + public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b", "Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency()); @@ -51,7 +54,7 @@ public class ChatModelParameters { public static List getParameters() { return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT, - CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, + CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_VERSION, CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT); } @@ -90,6 +93,12 @@ public class ChatModelParameters { ModelProvider.DEMO_CHAT_MODEL.getApiKey())); } + private static List getApiVersionDependency() { + return getDependency(CHAT_MODEL_PROVIDER.getName(), + Lists.newArrayList(OpenAiModelFactory.PROVIDER), + ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_API_VERSION)); + } + private static List getModelNameDependency() { return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(), ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME, diff --git a/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java index 6bd56fae8..9ac5fe0df 100644 --- a/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java +++ b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java @@ -70,6 +70,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { private final OpenAiClient client; private final String baseUrl; private final String modelName; + private final String apiVersion; private final Double temperature; private final Double topP; private final List stop; @@ -88,7 +89,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { private final List listeners; @Builder - public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName, + public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName, String apiVersion, Double temperature, Double topP, List stop, Integer maxTokens, Double presencePenalty, Double frequencyPenalty, Map logitBias, String responseFormat, Boolean strictJsonSchema, Integer seed, String user, @@ -104,12 +105,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { timeout = getOrDefault(timeout, ofSeconds(60)); - this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl) + this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl).apiVersion(apiVersion) .organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout) .readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests) .logResponses(logResponses).userAgent(DEFAULT_USER_AGENT) .customHeaders(customHeaders).build(); this.modelName = getOrDefault(modelName, GPT_3_5_TURBO); + this.apiVersion = apiVersion; this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; this.stop = stop; diff --git a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java index 2bd90eb55..a5e815b5e 100644 --- a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java @@ -18,11 +18,13 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean { public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1"; public static final String DEFAULT_MODEL_NAME = "gpt-4o-mini"; public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002"; + public static final String DEFAULT_API_VERSION = "2024-02-01"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl()) .modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt()) + .apiVersion(modelConfig.getApiVersion()) .temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP()) .maxRetries(modelConfig.getMaxRetries()) .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) diff --git a/pom.xml b/pom.xml index 0f707286b..359a90c1d 100644 --- a/pom.xml +++ b/pom.xml @@ -66,7 +66,7 @@ 4.5.1 2.2.6 3.17 - 0.34.0 + 0.35.0 0.27.1 4.0.8