(improvement)[build] Use Spotless to customize the code formatting (#1750)

This commit is contained in:
lexluo09
2024-10-04 00:05:04 +08:00
committed by GitHub
parent 44d1cde34f
commit 71a9954be5
521 changed files with 7811 additions and 13046 deletions

View File

@@ -13,10 +13,10 @@ import java.util.Objects;
/**
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
* from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert
* models into ONNX format can be found <a
* href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
* Many models already converted to ONNX format are available <a
* href="https://huggingface.co/Xenova">here</a>. Copy from
* models into ONNX format can be found <a href=
* "https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>. Many
* models already converted to ONNX format are available
* <a href="https://huggingface.co/Xenova">here</a>. Copy from
* dev.langchain4j.model.embedding.OnnxEmbeddingModel.
*/
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
@@ -28,9 +28,8 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
synchronized (S2OnnxEmbeddingModel.class) {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
URL resource =
AbstractInProcessEmbeddingModel.class.getResource(
"/bert-vocabulary-en.txt");
URL resource = AbstractInProcessEmbeddingModel.class
.getResource("/bert-vocabulary-en.txt");
if (StringUtils.isNotBlank(vocabularyPath)) {
try {
resource = Paths.get(vocabularyPath).toUri().toURL();
@@ -56,15 +55,14 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
}
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
return cachedModel == null
|| !Objects.equals(cachedModelPath, pathToModel)
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
}
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
try {
return new OnnxBertBiEncoder(
Files.newInputStream(pathToModel), vocabularyFile, PoolingMode.MEAN);
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel), vocabularyFile,
PoolingMode.MEAN);
} catch (IOException e) {
throw new RuntimeException(e);
}

View File

@@ -60,8 +60,8 @@ import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
* gpt-4. You can find description of parameters <a
* href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
* gpt-4. You can find description of parameters
* <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -88,32 +88,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final List<ChatModelListener> listeners;
@Builder
public OpenAiChatModel(
String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
List<String> stop,
Integer maxTokens,
Double presencePenalty,
Double frequencyPenalty,
Map<String, Integer> logitBias,
String responseFormat,
Boolean strictJsonSchema,
Integer seed,
String user,
Boolean strictTools,
Boolean parallelToolCalls,
Duration timeout,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ChatModelListener> listeners) {
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
Double temperature, Double topP, List<String> stop, Integer maxTokens,
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries,
Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer,
Map<String, String> customHeaders, List<ChatModelListener> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -123,21 +104,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
timeout = getOrDefault(timeout, ofSeconds(60));
this.client =
OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout)
.proxy(proxy)
.logRequests(logRequests)
.logResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders)
.build();
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders).build();
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
@@ -146,14 +117,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.logitBias = logitBias;
this.responseFormat =
responseFormat == null
? null
: ResponseFormat.builder()
.type(
ResponseFormatType.valueOf(
responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.responseFormat = responseFormat == null ? null
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
this.seed = seed;
this.user = user;
@@ -183,61 +150,44 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null, this.responseFormat);
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(
messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification,
this.responseFormat);
}
@Override
public ChatResponse chat(ChatRequest request) {
Response<AiMessage> response =
generate(
request.messages(),
request.toolSpecifications(),
null,
generate(request.messages(), request.toolSpecifications(), null,
getOrDefault(
toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
this.responseFormat));
return ChatResponse.builder()
.aiMessage(response.content())
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.build();
return ChatResponse.builder().aiMessage(response.content())
.tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build();
}
private Response<AiMessage> generate(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted,
private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted,
ResponseFormat responseFormat) {
if (responseFormat != null
&& responseFormat.type() == JSON_SCHEMA
if (responseFormat != null && responseFormat.type() == JSON_SCHEMA
&& responseFormat.jsonSchema() == null) {
responseFormat = null;
}
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder()
.model(modelName)
.messages(toOpenAiMessages(messages))
.topP(topP)
.stop(stop)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user)
.parallelToolCalls(parallelToolCalls);
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName).messages(toOpenAiMessages(messages)).topP(topP).stop(stop)
.maxTokens(maxTokens).presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty).logitBias(logitBias)
.responseFormat(responseFormat).seed(seed).user(user)
.parallelToolCalls(parallelToolCalls);
if (!(baseUrl.contains(ZHIPU))) {
requestBuilder.temperature(temperature);
@@ -257,40 +207,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext =
new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
listeners.forEach(listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
try {
ChatCompletionResponse chatCompletionResponse =
withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
Response<AiMessage> response =
Response.from(
aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(
chatCompletionResponse.choices().get(0).finishReason()));
Response<AiMessage> response = Response.from(aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()));
ChatModelResponse modelListenerResponse =
createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
ChatModelResponseContext responseContext =
new ChatModelResponseContext(
modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
ChatModelResponse modelListenerResponse = createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
return response;
} catch (RuntimeException e) {
@@ -305,14 +248,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
ChatModelErrorContext errorContext =
new ChatModelErrorContext(error, modelListenerRequest, null, attributes);
listeners.forEach(
listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
throw e;
}
@@ -328,8 +270,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
public static OpenAiChatModelBuilder builder() {
for (OpenAiChatModelBuilderFactory factory :
loadFactories(OpenAiChatModelBuilderFactory.class)) {
for (OpenAiChatModelBuilderFactory factory : loadFactories(
OpenAiChatModelBuilderFactory.class)) {
return factory.get();
}
return new OpenAiChatModelBuilder();

View File

@@ -3,9 +3,8 @@ package dev.langchain4j.model.openai;
public enum OpenAiChatModelName {
GPT_3_5_TURBO("gpt-3.5-turbo"), // alias
@Deprecated
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"),
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106"),
GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"), GPT_3_5_TURBO_1106(
"gpt-3.5-turbo-1106"), GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k"), // alias
@Deprecated
@@ -13,22 +12,18 @@ public enum OpenAiChatModelName {
GPT_4("gpt-4"), // alias
@Deprecated
GPT_4_0314("gpt-4-0314"),
GPT_4_0613("gpt-4-0613"),
GPT_4_0314("gpt-4-0314"), GPT_4_0613("gpt-4-0613"),
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"), // alias
GPT_4_1106_PREVIEW("gpt-4-1106-preview"),
GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
GPT_4_1106_PREVIEW("gpt-4-1106-preview"), GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
GPT_4_32K("gpt-4-32k"), // alias
GPT_4_32K_0314("gpt-4-32k-0314"),
GPT_4_32K_0613("gpt-4-32k-0613"),
GPT_4_32K_0314("gpt-4-32k-0314"), GPT_4_32K_0613("gpt-4-32k-0613"),
@Deprecated
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
GPT_4_O("gpt-4o"),
GPT_4_O_MINI("gpt-4o-mini");
GPT_4_O("gpt-4o"), GPT_4_O_MINI("gpt-4o-mini");
private final String stringValue;

View File

@@ -1,9 +1,7 @@
package dev.langchain4j.model.zhipu;
public enum ChatCompletionModel {
GLM_4("glm-4"),
GLM_3_TURBO("glm-3-turbo"),
CHATGLM_TURBO("chatglm_turbo");
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
private final String value;

View File

@@ -27,8 +27,8 @@ import static java.util.Collections.singletonList;
/**
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
* glm-4. You can find description of parameters <a
* href="https://open.bigmodel.cn/dev/api">here</a>.
* glm-4. You can find description of parameters
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
*/
public class ZhipuAiChatModel implements ChatLanguageModel {
@@ -41,15 +41,8 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
private final ZhipuAiClient client;
@Builder
public ZhipuAiChatModel(
String baseUrl,
String apiKey,
Double temperature,
Double topP,
String model,
Integer maxRetries,
Integer maxToken,
Boolean logRequests,
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
Boolean logResponses) {
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
this.temperature = getOrDefault(temperature, 0.7);
@@ -57,18 +50,14 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512);
this.client =
ZhipuAiClient.builder()
.baseUrl(this.baseUrl)
.apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
.build();
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false)).build();
}
public static ZhipuAiChatModelBuilder builder() {
for (ZhipuAiChatModelBuilderFactory factories :
loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
ZhipuAiChatModelBuilderFactory.class)) {
return factories.get();
}
return new ZhipuAiChatModelBuilder();
@@ -80,15 +69,13 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
.topP(topP)
.toolChoice(AUTO)
.messages(toZhipuAiMessages(messages));
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
@@ -96,17 +83,15 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
ChatCompletionResponse response =
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
return Response.from(
aiMessageFrom(response),
tokenUsageFrom(response.getUsage()),
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(
messages, toolSpecification != null ? singletonList(toolSpecification) : null);
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages,
toolSpecification != null ? singletonList(toolSpecification) : null);
}
public static class ZhipuAiChatModelBuilder {