package dev.langchain4j.model.embedding;

import dev.langchain4j.model.embedding.onnx.AbstractInProcessEmbeddingModel;
import dev.langchain4j.model.embedding.onnx.OnnxBertBiEncoder;
import dev.langchain4j.model.embedding.onnx.PoolingMode;
import org.apache.commons.lang3.StringUtils;

import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import java.util.concurrent.Executors;

/**
 * An embedding model that runs within your Java application's process. Any BERT-based model
 * (e.g., from HuggingFace) can be used, as long as it is in ONNX format.
 *
 * <p>Adapted from {@code dev.langchain4j.model.embedding.OnnxEmbeddingModel}, with one addition:
 * the most recently loaded encoder is cached statically, keyed by the model path and vocabulary
 * path, so constructing several instances with the same arguments loads the model only once.
 *
 * <p>Each instance keeps a reference to the encoder it resolved at construction time. Creating a
 * later instance with different paths replaces the static cache entry but does not change the
 * encoder used by instances that already exist.
 */
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {

    /** Classpath resource used as the vocabulary when no vocabulary path is supplied. */
    private static final String DEFAULT_VOCABULARY_RESOURCE = "/bert-vocabulary-en.txt";

    /**
     * Most recently loaded encoder together with the paths it was loaded from. Key and value are
     * published through one volatile reference so a reader never observes a new encoder paired
     * with stale paths (or the reverse).
     */
    private static volatile CachedEncoder cache;

    /** Encoder resolved for this instance at construction time; never reassigned. */
    private final OnnxBertBiEncoder model;

    /**
     * Creates a model backed by the ONNX file at {@code pathToModel}.
     *
     * @param pathToModel file-system path of the ONNX model; must not be {@code null}
     * @param vocabularyPath file-system path of the vocabulary file; when {@code null} or blank,
     *     the bundled {@code /bert-vocabulary-en.txt} classpath resource is used
     * @throws NullPointerException if {@code pathToModel} is {@code null}, or if the default
     *     vocabulary resource is needed but not found on the classpath
     * @throws RuntimeException if the model or vocabulary cannot be read
     */
    public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
        // NOTE(review): a fresh single-thread executor is created per instance and is not shut
        // down in this class — confirm the superclass (or the caller) owns its lifecycle.
        super(Executors.newSingleThreadExecutor());
        this.model = obtainEncoder(pathToModel, vocabularyPath);
    }

    /**
     * Creates a model backed by the ONNX file at {@code pathToModel}, using the bundled default
     * vocabulary.
     *
     * @param pathToModel file-system path of the ONNX model; must not be {@code null}
     */
    public S2OnnxEmbeddingModel(String pathToModel) {
        this(pathToModel, null);
    }

    /** Returns the encoder this instance resolved when it was constructed. */
    @Override
    protected OnnxBertBiEncoder model() {
        return model;
    }

    /**
     * Returns the cached encoder when it was loaded from the same paths, otherwise loads a new
     * encoder and replaces the cache entry. Loading happens under the class lock so that
     * concurrent constructors asking for the same paths load the model only once.
     */
    private static OnnxBertBiEncoder obtainEncoder(String pathToModel, String vocabularyPath) {
        Objects.requireNonNull(pathToModel, "pathToModel must not be null");

        CachedEncoder entry = cache;
        if (entry == null || !entry.matches(pathToModel, vocabularyPath)) {
            synchronized (S2OnnxEmbeddingModel.class) {
                // Re-read under the lock: another thread may have loaded the same model while
                // this one was waiting.
                entry = cache;
                if (entry == null || !entry.matches(pathToModel, vocabularyPath)) {
                    OnnxBertBiEncoder encoder =
                            loadFromFileSystem(Paths.get(pathToModel), resolveVocabulary(vocabularyPath));
                    entry = new CachedEncoder(pathToModel, vocabularyPath, encoder);
                    cache = entry;
                }
            }
        }
        return entry.encoder;
    }

    /**
     * Resolves the vocabulary location: the given file-system path when it is not blank,
     * otherwise the bundled default classpath resource (which may be {@code null} if absent).
     */
    private static URL resolveVocabulary(String vocabularyPath) {
        if (StringUtils.isNotBlank(vocabularyPath)) {
            try {
                return Paths.get(vocabularyPath).toUri().toURL();
            } catch (MalformedURLException e) {
                throw new RuntimeException("Invalid vocabulary path: " + vocabularyPath, e);
            }
        }
        return AbstractInProcessEmbeddingModel.class.getResource(DEFAULT_VOCABULARY_RESOURCE);
    }

    /**
     * Builds an encoder with mean pooling from a model file and a vocabulary URL.
     *
     * @param pathToModel path of the ONNX model file
     * @param vocabularyFile URL of the vocabulary file; must not be {@code null}
     * @return the newly created encoder
     * @throws NullPointerException if {@code vocabularyFile} is {@code null}
     * @throws RuntimeException wrapping the {@link IOException} if either input cannot be
     *     opened, read, or closed
     */
    static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
        Objects.requireNonNull(vocabularyFile, "vocabularyFile must not be null (vocabulary resource not found?)");
        // Both streams are closed here, including when opening the second one fails.
        // Assumes the encoder finishes reading both streams inside its constructor — TODO confirm.
        try (InputStream modelStream = Files.newInputStream(pathToModel);
                InputStream vocabularyStream = vocabularyFile.openStream()) {
            return new OnnxBertBiEncoder(modelStream, vocabularyStream, PoolingMode.MEAN);
        } catch (IOException e) {
            throw new RuntimeException(
                    "Failed to load ONNX model from " + pathToModel + " with vocabulary " + vocabularyFile, e);
        }
    }

    /** Immutable pairing of a loaded encoder with the constructor arguments it was loaded for. */
    private static final class CachedEncoder {

        private final String modelPath;
        private final String vocabularyPath;
        private final OnnxBertBiEncoder encoder;

        CachedEncoder(String modelPath, String vocabularyPath, OnnxBertBiEncoder encoder) {
            this.modelPath = modelPath;
            this.vocabularyPath = vocabularyPath;
            this.encoder = encoder;
        }

        /** Returns whether this entry was loaded from exactly the given paths (null-safe). */
        boolean matches(String modelPath, String vocabularyPath) {
            return Objects.equals(this.modelPath, modelPath)
                    && Objects.equals(this.vocabularyPath, vocabularyPath);
        }
    }
}