mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
76 lines
3.1 KiB
Java
76 lines
3.1 KiB
Java
package dev.langchain4j.model.embedding;
|
|
|
|
import dev.langchain4j.model.embedding.onnx.AbstractInProcessEmbeddingModel;
import dev.langchain4j.model.embedding.onnx.OnnxBertBiEncoder;
import dev.langchain4j.model.embedding.onnx.PoolingMode;
import org.apache.commons.lang3.StringUtils;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import java.util.concurrent.Executors;
|
|
|
|
/**
|
|
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
|
|
* from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert
|
|
* models into ONNX format can be found <a href=
|
|
* "https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>. Many
|
|
* models already converted to ONNX format are available
|
|
* <a href="https://huggingface.co/Xenova">here</a>. Copy from
|
|
* dev.langchain4j.model.embedding.OnnxEmbeddingModel.
|
|
*/
|
|
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
|
private static volatile OnnxBertBiEncoder cachedModel;
|
|
private static volatile String cachedModelPath;
|
|
private static volatile String cachedVocabularyPath;
|
|
|
|
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
|
super(Executors.newSingleThreadExecutor());
|
|
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
|
synchronized (S2OnnxEmbeddingModel.class) {
|
|
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
|
URL resource = AbstractInProcessEmbeddingModel.class
|
|
.getResource("/bert-vocabulary-en.txt");
|
|
if (StringUtils.isNotBlank(vocabularyPath)) {
|
|
try {
|
|
resource = Paths.get(vocabularyPath).toUri().toURL();
|
|
} catch (MalformedURLException e) {
|
|
throw new RuntimeException(e);
|
|
}
|
|
}
|
|
cachedModel = loadFromFileSystem(Paths.get(pathToModel), resource);
|
|
cachedModelPath = pathToModel;
|
|
cachedVocabularyPath = vocabularyPath;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
public S2OnnxEmbeddingModel(String pathToModel) {
|
|
this(pathToModel, null);
|
|
}
|
|
|
|
@Override
|
|
protected OnnxBertBiEncoder model() {
|
|
return cachedModel;
|
|
}
|
|
|
|
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
|
|
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|
|
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
|
|
}
|
|
|
|
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
|
try {
|
|
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel),
|
|
vocabularyFile.openStream(), PoolingMode.MEAN);
|
|
} catch (IOException e) {
|
|
throw new RuntimeException(e);
|
|
}
|
|
}
|
|
}
|