feat:add openapi supports ApiVersion (#2050)

This commit is contained in:
zyclove
2025-02-12 20:28:10 +08:00
committed by GitHub
parent f412ae4539
commit cb139a54e8
5 changed files with 18 additions and 4 deletions

View File

@@ -18,6 +18,7 @@ public class ChatModelConfig implements Serializable {
private String baseUrl; private String baseUrl;
private String apiKey; private String apiKey;
private String modelName; private String modelName;
private String apiVersion;
private Double temperature = 0.0d; private Double temperature = 0.0d;
private Long timeOut = 60L; private Long timeOut = 60L;
private String endpoint; private String endpoint;

View File

@@ -34,6 +34,9 @@ public class ChatModelParameters {
public static final Parameter CHAT_MODEL_API_KEY = new Parameter("apiKey", "", "ApiKey", "", public static final Parameter CHAT_MODEL_API_KEY = new Parameter("apiKey", "", "ApiKey", "",
"password", MODULE_NAME, null, getApiKeyDependency()); "password", MODULE_NAME, null, getApiKeyDependency());
public static final Parameter CHAT_MODEL_API_VERSION = new Parameter("apiVersion", "2024-02-01", "ApiVersion", "",
"string", MODULE_NAME, null, getApiVersionDependency());
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b", public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
"Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency()); "Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
@@ -51,7 +54,7 @@ public class ChatModelParameters {
public static List<Parameter> getParameters() { public static List<Parameter> getParameters() {
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT, return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_VERSION,
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT); CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
} }
@@ -90,6 +93,12 @@ public class ChatModelParameters {
ModelProvider.DEMO_CHAT_MODEL.getApiKey())); ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
} }
private static List<Parameter.Dependency> getApiVersionDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_API_VERSION));
}
private static List<Parameter.Dependency> getModelNameDependency() { private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(), return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME, ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,

View File

@@ -70,6 +70,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final OpenAiClient client; private final OpenAiClient client;
private final String baseUrl; private final String baseUrl;
private final String modelName; private final String modelName;
private final String apiVersion;
private final Double temperature; private final Double temperature;
private final Double topP; private final Double topP;
private final List<String> stop; private final List<String> stop;
@@ -88,7 +89,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final List<ChatModelListener> listeners; private final List<ChatModelListener> listeners;
@Builder @Builder
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName, public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName, String apiVersion,
Double temperature, Double topP, List<String> stop, Integer maxTokens, Double temperature, Double topP, List<String> stop, Integer maxTokens,
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias, Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
String responseFormat, Boolean strictJsonSchema, Integer seed, String user, String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
@@ -104,12 +105,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
timeout = getOrDefault(timeout, ofSeconds(60)); timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl) this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl).apiVersion(apiVersion)
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout) .organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests) .readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT) .logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders).build(); .customHeaders(customHeaders).build();
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO); this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.apiVersion = apiVersion;
this.temperature = getOrDefault(temperature, 0.7); this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP; this.topP = topP;
this.stop = stop; this.stop = stop;

View File

@@ -18,11 +18,13 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1"; public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
public static final String DEFAULT_MODEL_NAME = "gpt-4o-mini"; public static final String DEFAULT_MODEL_NAME = "gpt-4o-mini";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002"; public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
public static final String DEFAULT_API_VERSION = "2024-02-01";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl()) return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt()) .modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
.apiVersion(modelConfig.getApiVersion())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP()) .temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()) .maxRetries(modelConfig.getMaxRetries())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))

View File

@@ -66,7 +66,7 @@
<mockito-inline.version>4.5.1</mockito-inline.version> <mockito-inline.version>4.5.1</mockito-inline.version>
<easyexcel.version>2.2.6</easyexcel.version> <easyexcel.version>2.2.6</easyexcel.version>
<poi.version>3.17</poi.version> <poi.version>3.17</poi.version>
<langchain4j.version>0.34.0</langchain4j.version> <langchain4j.version>0.35.0</langchain4j.version>
<langchain4j.embedding.version>0.27.1</langchain4j.embedding.version> <langchain4j.embedding.version>0.27.1</langchain4j.embedding.version>
<!-- <postgresql.version>42.7.1</postgresql.version>--> <!-- <postgresql.version>42.7.1</postgresql.version>-->
<st.version>4.0.8</st.version> <st.version>4.0.8</st.version>