(improvement)(chat) Reduce frequent loading of embedding models to improve loading performance. (#1478)

This commit is contained in:
lexluo09
2024-07-30 11:25:03 +08:00
committed by GitHub
parent 23af977972
commit 9a1fac5d4c
6 changed files with 55 additions and 40 deletions

View File

@@ -8,6 +8,7 @@ import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
/**
* An embedding model that runs within your Java application's process.
@@ -18,36 +19,42 @@ import java.nio.file.Paths;
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
*/
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
// Encoder shared by every instance of this class; assigned by the two-argument constructor
// while holding the S2OnnxEmbeddingModel.class monitor.
private static volatile OnnxBertBiEncoder cachedModel;
// Model path recorded when cachedModel was loaded; compared in shouldReloadModel.
private static volatile String cachedModelPath;
// Vocabulary path (may be null) recorded when cachedModel was loaded; compared in shouldReloadModel.
private static volatile String cachedVocabularyPath;
// NOTE(review): looks like the pre-change field that cachedModel supersedes — confirm and remove.
private static OnnxBertBiEncoder model = null;
/**
 * Creates an embedding model whose encoder is held in a static cache and is only (re)loaded
 * when no encoder is cached yet or when the requested model/vocabulary path differs from the
 * one recorded for the cached encoder.
 *
 * NOTE(review): the cache is static, so constructing an instance with different paths replaces
 * the encoder stored in {@code cachedModel} for the whole class — confirm this is intended.
 *
 * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
 * @param vocabularyPath Path to a vocabulary file; when blank, the bundled
 *                       "/bert-vocabulary-en.txt" classpath resource is used instead.
 * @throws RuntimeException if the vocabulary path cannot be converted to a URL
 */
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
    // Unsynchronized check first, so the "already loaded with these paths" case takes no lock.
    if (shouldReloadModel(pathToModel, vocabularyPath)) {
        synchronized (S2OnnxEmbeddingModel.class) {
            // Re-check under the lock: another thread may have finished the same load meanwhile.
            if (shouldReloadModel(pathToModel, vocabularyPath)) {
                URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
                if (StringUtils.isNotBlank(vocabularyPath)) {
                    try {
                        resource = Paths.get(vocabularyPath).toUri().toURL();
                    } catch (MalformedURLException e) {
                        throw new RuntimeException(e);
                    }
                }
                // Store the encoder before the paths it was loaded from.
                cachedModel = loadFromFileSystem(Paths.get(pathToModel), resource);
                cachedModelPath = pathToModel;
                cachedVocabularyPath = vocabularyPath;
            }
        }
    }
}
/**
 * Creates the model from a model path only, delegating to the two-argument constructor
 * with a {@code null} vocabulary path.
 *
 * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
 */
public S2OnnxEmbeddingModel(String pathToModel) {
this(pathToModel, null);
}
/**
 * Returns the encoder currently held in the static {@code cachedModel} field.
 *
 * @return the cached {@code OnnxBertBiEncoder}
 */
@Override
protected OnnxBertBiEncoder model() {
    return cachedModel;
}
/**
 * Decides whether the cached encoder has to be loaded (again) for the requested files.
 *
 * @param pathToModel    model path requested by the caller
 * @param vocabularyPath vocabulary path requested by the caller (may be {@code null})
 * @return {@code true} when nothing is cached yet, or when either requested path differs
 *         from the path recorded for the cached encoder
 */
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
    // Nothing loaded so far: a load is always required.
    if (cachedModel == null) {
        return true;
    }
    // A different model file invalidates the cache regardless of the vocabulary.
    if (!Objects.equals(cachedModelPath, pathToModel)) {
        return true;
    }
    // Same model file: reload only if the vocabulary changed.
    return !Objects.equals(cachedVocabularyPath, vocabularyPath);
}
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
@@ -61,4 +68,4 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
throw new RuntimeException(e);
}
}
}
}