mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
(improvement)(chat) Reduce frequent loading of embedding models to improve loading performance. (#1478)
This commit is contained in:
@@ -8,6 +8,7 @@ import java.net.URL;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* An embedding model that runs within your Java application's process.
|
||||
@@ -18,36 +19,42 @@ import java.nio.file.Paths;
|
||||
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
|
||||
*/
|
||||
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
private static volatile OnnxBertBiEncoder cachedModel;
|
||||
private static volatile String cachedModelPath;
|
||||
private static volatile String cachedVocabularyPath;
|
||||
|
||||
private static OnnxBertBiEncoder model = null;
|
||||
|
||||
/**
 * Creates an embedding model that reuses a class-level cached ONNX encoder.
 *
 * <p>The encoder is (re)loaded only when {@code shouldReloadModel} reports that no encoder is
 * cached yet, or that the requested model/vocabulary paths differ from the cached ones.
 * Otherwise the already loaded encoder is kept and no file access happens here.
 *
 * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
 * @param vocabularyPath The path to the vocabulary file. When blank, the
 *     "/bert-vocabulary-en.txt" classpath resource is used instead.
 * @throws RuntimeException if {@code vocabularyPath} cannot be converted to a URL
 */
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
    // Cheap check first so the common "already cached" case does not take the lock.
    if (shouldReloadModel(pathToModel, vocabularyPath)) {
        synchronized (S2OnnxEmbeddingModel.class) {
            // Check again while holding the lock: another thread may have finished
            // loading the same model between the first check and lock acquisition.
            if (shouldReloadModel(pathToModel, vocabularyPath)) {
                URL resource =
                        AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
                if (StringUtils.isNotBlank(vocabularyPath)) {
                    try {
                        resource = Paths.get(vocabularyPath).toUri().toURL();
                    } catch (MalformedURLException e) {
                        throw new RuntimeException(e);
                    }
                }
                cachedModel = loadFromFileSystem(Paths.get(pathToModel), resource);
                // Record which paths the cached encoder was built from so later
                // constructions with the same arguments can skip loading.
                cachedModelPath = pathToModel;
                cachedVocabularyPath = vocabularyPath;
            }
        }
    }
}
|
||||
|
||||
/**
 * Creates an embedding model from an ONNX file, without a custom vocabulary path.
 *
 * <p>Delegates to the two-argument constructor with a {@code null} vocabulary path.
 *
 * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
 */
public S2OnnxEmbeddingModel(String pathToModel) {
    this(pathToModel, null);
}
|
||||
|
||||
@Override
|
||||
protected OnnxBertBiEncoder model() {
|
||||
return model;
|
||||
return cachedModel;
|
||||
}
|
||||
|
||||
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
|
||||
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|
||||
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
|
||||
}
|
||||
|
||||
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
||||
@@ -61,4 +68,4 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user