(improvement)(build) Add spotless during the build process. (#1639)
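The hunks below are the mechanical reformatting that the new formatter produced; the wrapped lines use 4-space continuation indents, which matches google-java-format's AOSP style. As a rough sketch of how spotless can be wired into a Maven build for this purpose (the plugin version, the chosen style, and the bound goal are illustrative assumptions, not details taken from this commit):

    <!-- Hypothetical pom.xml snippet; version and style values are
         assumptions for illustration, not copied from this commit. -->
    <plugin>
        <groupId>com.diffplug.spotless</groupId>
        <artifactId>spotless-maven-plugin</artifactId>
        <version>2.40.0</version>
        <configuration>
            <java>
                <!-- AOSP style means 4-space indentation, which is what
                     the reformatted lines in this diff use -->
                <googleJavaFormat>
                    <style>AOSP</style>
                </googleJavaFormat>
            </java>
        </configuration>
        <executions>
            <execution>
                <goals>
                    <!-- `check` fails the build on unformatted files;
                         `apply` rewrites them in place -->
                    <goal>check</goal>
                </goals>
            </execution>
        </executions>
    </plugin>

Under a setup like this, `mvn spotless:apply` rewrites the sources in one pass (the likely origin of the formatting-only hunks below), while the bound `check` goal fails the build whenever a file drifts from the configured style.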
This commit is contained in:
S2OnnxEmbeddingModel.java
@@ -11,12 +11,13 @@ import java.nio.file.Paths;
 import java.util.Objects;

 /**
- * An embedding model that runs within your Java application's process.
- * Any BERT-based model (e.g., from HuggingFace) can be used, as long as it is in ONNX format.
- * Information on how to convert models into ONNX format can be found <a
+ * An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
+ * from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert
+ * models into ONNX format can be found <a
  * href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
- * Many models already converted to ONNX format are available <a href="https://huggingface.co/Xenova">here</a>.
- * Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
+ * Many models already converted to ONNX format are available <a
+ * href="https://huggingface.co/Xenova">here</a>. Copy from
+ * dev.langchain4j.model.embedding.OnnxEmbeddingModel.
  */
 public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
     private static volatile OnnxBertBiEncoder cachedModel;
@@ -27,7 +28,9 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
         if (shouldReloadModel(pathToModel, vocabularyPath)) {
             synchronized (S2OnnxEmbeddingModel.class) {
                 if (shouldReloadModel(pathToModel, vocabularyPath)) {
-                    URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
+                    URL resource =
+                            AbstractInProcessEmbeddingModel.class.getResource(
+                                    "/bert-vocabulary-en.txt");
                     if (StringUtils.isNotBlank(vocabularyPath)) {
                         try {
                             resource = Paths.get(vocabularyPath).toUri().toURL();
@@ -53,19 +56,17 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
     }

     private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
-        return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
+        return cachedModel == null
+                || !Objects.equals(cachedModelPath, pathToModel)
                 || !Objects.equals(cachedVocabularyPath, vocabularyPath);
     }

     static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
         try {
             return new OnnxBertBiEncoder(
-                    Files.newInputStream(pathToModel),
-                    vocabularyFile,
-                    PoolingMode.MEAN
-            );
+                    Files.newInputStream(pathToModel), vocabularyFile, PoolingMode.MEAN);
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
     }
 }

OpenAiChatModel.java
@@ -44,8 +44,9 @@ import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;

 /**
- * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
- * You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
+ * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
+ * gpt-4. You can find description of parameters <a
+ * href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
 @Slf4j
 public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -67,31 +68,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     private final Integer maxRetries;
     private final Tokenizer tokenizer;

-    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
+    private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>>
+            listeners;

     @Builder
-    public OpenAiChatModel(String baseUrl,
-                           String apiKey,
-                           String organizationId,
-                           String modelName,
-                           Double temperature,
-                           Double topP,
-                           List<String> stop,
-                           Integer maxTokens,
-                           Double presencePenalty,
-                           Double frequencyPenalty,
-                           Map<String, Integer> logitBias,
-                           String responseFormat,
-                           Integer seed,
-                           String user,
-                           Duration timeout,
-                           Integer maxRetries,
-                           Proxy proxy,
-                           Boolean logRequests,
-                           Boolean logResponses,
-                           Tokenizer tokenizer,
-                           Map<String, String> customHeaders,
-                           List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
+    public OpenAiChatModel(
+            String baseUrl,
+            String apiKey,
+            String organizationId,
+            String modelName,
+            Double temperature,
+            Double topP,
+            List<String> stop,
+            Integer maxTokens,
+            Double presencePenalty,
+            Double frequencyPenalty,
+            Map<String, Integer> logitBias,
+            String responseFormat,
+            Integer seed,
+            String user,
+            Duration timeout,
+            Integer maxRetries,
+            Proxy proxy,
+            Boolean logRequests,
+            Boolean logResponses,
+            Tokenizer tokenizer,
+            Map<String, String> customHeaders,
+            List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {

         baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -101,20 +104,21 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {

         timeout = getOrDefault(timeout, ofSeconds(60));

-        this.client = OpenAiClient.builder()
-                .openAiApiKey(apiKey)
-                .baseUrl(baseUrl)
-                .organizationId(organizationId)
-                .callTimeout(timeout)
-                .connectTimeout(timeout)
-                .readTimeout(timeout)
-                .writeTimeout(timeout)
-                .proxy(proxy)
-                .logRequests(logRequests)
-                .logResponses(logResponses)
-                .userAgent(DEFAULT_USER_AGENT)
-                .customHeaders(customHeaders)
-                .build();
+        this.client =
+                OpenAiClient.builder()
+                        .openAiApiKey(apiKey)
+                        .baseUrl(baseUrl)
+                        .organizationId(organizationId)
+                        .callTimeout(timeout)
+                        .connectTimeout(timeout)
+                        .readTimeout(timeout)
+                        .writeTimeout(timeout)
+                        .proxy(proxy)
+                        .logRequests(logRequests)
+                        .logResponses(logResponses)
+                        .userAgent(DEFAULT_USER_AGENT)
+                        .customHeaders(customHeaders)
+                        .build();
         this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
         this.temperature = getOrDefault(temperature, 0.7);
         this.topP = topP;
@@ -141,31 +145,34 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     }

     @Override
-    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
+    public Response<AiMessage> generate(
+            List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
         return generate(messages, toolSpecifications, null);
     }

     @Override
-    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
+    public Response<AiMessage> generate(
+            List<ChatMessage> messages, ToolSpecification toolSpecification) {
         return generate(messages, singletonList(toolSpecification), toolSpecification);
     }

-    private Response<AiMessage> generate(List<ChatMessage> messages,
-                                         List<ToolSpecification> toolSpecifications,
-                                         ToolSpecification toolThatMustBeExecuted
-    ) {
-        ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
-                .model(modelName)
-                .messages(toOpenAiMessages(messages))
-                .topP(topP)
-                .stop(stop)
-                .maxTokens(maxTokens)
-                .presencePenalty(presencePenalty)
-                .frequencyPenalty(frequencyPenalty)
-                .logitBias(logitBias)
-                .responseFormat(responseFormat)
-                .seed(seed)
-                .user(user);
+    private Response<AiMessage> generate(
+            List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications,
+            ToolSpecification toolThatMustBeExecuted) {
+        ChatCompletionRequest.Builder requestBuilder =
+                ChatCompletionRequest.builder()
+                        .model(modelName)
+                        .messages(toOpenAiMessages(messages))
+                        .topP(topP)
+                        .stop(stop)
+                        .maxTokens(maxTokens)
+                        .presencePenalty(presencePenalty)
+                        .frequencyPenalty(frequencyPenalty)
+                        .logitBias(logitBias)
+                        .responseFormat(responseFormat)
+                        .seed(seed)
+                        .user(user);
         if (!(baseUrl.contains(ZHIPU))) {
             requestBuilder.temperature(temperature);
         }
@@ -181,36 +188,37 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {

         ChatLanguageModelRequest modelListenerRequest =
                 createModelListenerRequest(request, messages, toolSpecifications);
-        listeners.forEach(listener -> {
-            try {
-                listener.onRequest(modelListenerRequest);
-            } catch (Exception e) {
-                log.warn("Exception while calling model listener", e);
-            }
-        });
+        listeners.forEach(
+                listener -> {
+                    try {
+                        listener.onRequest(modelListenerRequest);
+                    } catch (Exception e) {
+                        log.warn("Exception while calling model listener", e);
+                    }
+                });

         try {
             ChatCompletionResponse chatCompletionResponse =
                     withRetry(() -> client.chatCompletion(request).execute(), maxRetries);

-            Response<AiMessage> response = Response.from(
-                    aiMessageFrom(chatCompletionResponse),
-                    tokenUsageFrom(chatCompletionResponse.usage()),
-                    finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())
-            );
+            Response<AiMessage> response =
+                    Response.from(
+                            aiMessageFrom(chatCompletionResponse),
+                            tokenUsageFrom(chatCompletionResponse.usage()),
+                            finishReasonFrom(
+                                    chatCompletionResponse.choices().get(0).finishReason()));

-            ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
-                    chatCompletionResponse.id(),
-                    chatCompletionResponse.model(),
-                    response
-            );
-            listeners.forEach(listener -> {
-                try {
-                    listener.onResponse(modelListenerResponse, modelListenerRequest);
-                } catch (Exception e) {
-                    log.warn("Exception while calling model listener", e);
-                }
-            });
+            ChatLanguageModelResponse modelListenerResponse =
+                    createModelListenerResponse(
+                            chatCompletionResponse.id(), chatCompletionResponse.model(), response);
+            listeners.forEach(
+                    listener -> {
+                        try {
+                            listener.onResponse(modelListenerResponse, modelListenerRequest);
+                        } catch (Exception e) {
+                            log.warn("Exception while calling model listener", e);
+                        }
+                    });

             return response;
         } catch (RuntimeException e) {
@@ -222,13 +230,14 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
                 error = e;
             }

-            listeners.forEach(listener -> {
-                try {
-                    listener.onError(error, null, modelListenerRequest);
-                } catch (Exception e2) {
-                    log.warn("Exception while calling model listener", e2);
-                }
-            });
+            listeners.forEach(
+                    listener -> {
+                        try {
+                            listener.onError(error, null, modelListenerRequest);
+                        } catch (Exception e2) {
+                            log.warn("Exception while calling model listener", e2);
+                        }
+                    });
             throw e;
         }
     }
@@ -243,7 +252,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
     }

     public static OpenAiChatModelBuilder builder() {
-        for (OpenAiChatModelBuilderFactory factory : loadFactories(OpenAiChatModelBuilderFactory.class)) {
+        for (OpenAiChatModelBuilderFactory factory :
+                loadFactories(OpenAiChatModelBuilderFactory.class)) {
             return factory.get();
         }
         return new OpenAiChatModelBuilder();
@@ -261,4 +271,4 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
             return this;
         }
     }
-}
\ No newline at end of file
+}

ChatCompletionModel.java
@@ -15,4 +15,4 @@ public enum ChatCompletionModel {
     public String toString() {
         return this.value;
     }
-}
\ No newline at end of file
+}

ZhipuAiChatModel.java
@@ -26,8 +26,9 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories;
 import static java.util.Collections.singletonList;

 /**
- * Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and glm-4.
- * You can find description of parameters <a href="https://open.bigmodel.cn/dev/api">here</a>.
+ * Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
+ * glm-4. You can find description of parameters <a
+ * href="https://open.bigmodel.cn/dev/api">here</a>.
  */
 public class ZhipuAiChatModel implements ChatLanguageModel {

@@ -49,24 +50,25 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
             Integer maxRetries,
             Integer maxToken,
             Boolean logRequests,
-            Boolean logResponses
-    ) {
+            Boolean logResponses) {
         this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
         this.temperature = getOrDefault(temperature, 0.7);
         this.topP = topP;
         this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
         this.maxRetries = getOrDefault(maxRetries, 3);
         this.maxToken = getOrDefault(maxToken, 512);
-        this.client = ZhipuAiClient.builder()
-                .baseUrl(this.baseUrl)
-                .apiKey(apiKey)
-                .logRequests(getOrDefault(logRequests, false))
-                .logResponses(getOrDefault(logResponses, false))
-                .build();
+        this.client =
+                ZhipuAiClient.builder()
+                        .baseUrl(this.baseUrl)
+                        .apiKey(apiKey)
+                        .logRequests(getOrDefault(logRequests, false))
+                        .logResponses(getOrDefault(logResponses, false))
+                        .build();
     }

     public static ZhipuAiChatModelBuilder builder() {
-        for (ZhipuAiChatModelBuilderFactory factories : loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
+        for (ZhipuAiChatModelBuilderFactory factories :
+                loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
             return factories.get();
         }
         return new ZhipuAiChatModelBuilder();
@@ -78,36 +80,36 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
     }

     @Override
-    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
+    public Response<AiMessage> generate(
+            List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
         ensureNotEmpty(messages, "messages");

-        ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
-                .model(this.model)
-                .maxTokens(maxToken)
-                .stream(false)
-                .topP(topP)
-                .toolChoice(AUTO)
-                .messages(toZhipuAiMessages(messages));
+        ChatCompletionRequest.Builder requestBuilder =
+                ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
+                        .topP(topP)
+                        .toolChoice(AUTO)
+                        .messages(toZhipuAiMessages(messages));

         if (!isNullOrEmpty(toolSpecifications)) {
             requestBuilder.tools(toTools(toolSpecifications));
         }

-        ChatCompletionResponse response = withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
+        ChatCompletionResponse response =
+                withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
         return Response.from(
                 aiMessageFrom(response),
                 tokenUsageFrom(response.getUsage()),
-                finishReasonFrom(response.getChoices().get(0).getFinishReason())
-        );
+                finishReasonFrom(response.getChoices().get(0).getFinishReason()));
     }

     @Override
-    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
-        return generate(messages, toolSpecification != null ? singletonList(toolSpecification) : null);
+    public Response<AiMessage> generate(
+            List<ChatMessage> messages, ToolSpecification toolSpecification) {
+        return generate(
+                messages, toolSpecification != null ? singletonList(toolSpecification) : null);
     }

     public static class ZhipuAiChatModelBuilder {
-        public ZhipuAiChatModelBuilder() {
-        }
+        public ZhipuAiChatModelBuilder() {}
     }
 }