(improvement)(chat) Reduce frequent loading of embedding models to improve loading performance. (#1478)

This commit is contained in:
lexluo09
2024-07-30 11:25:03 +08:00
committed by GitHub
parent 23af977972
commit 9a1fac5d4c
6 changed files with 55 additions and 40 deletions

View File

@@ -4,9 +4,9 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.inmemory.spring.InMemoryAutoConfig;
import dev.langchain4j.provider.AzureModelFactory; import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory; import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.EmbeddingModelConstant;
import dev.langchain4j.provider.InMemoryModelFactory; import dev.langchain4j.provider.InMemoryModelFactory;
import dev.langchain4j.provider.OllamaModelFactory; import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory; import dev.langchain4j.provider.OpenAiModelFactory;
@@ -76,7 +76,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_NAME = public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH, new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName", "", "ModelName", "",
"string", "向量模型配置", null, "string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(), getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
@@ -90,7 +90,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
ZhipuModelFactory.PROVIDER ZhipuModelFactory.PROVIDER
), ),
ImmutableMap.of( ImmutableMap.of(
InMemoryModelFactory.PROVIDER, InMemoryAutoConfig.BGE_SMALL_ZH, InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002", OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
OllamaModelFactory.PROVIDER, "all-minilm", OllamaModelFactory.PROVIDER, "all-minilm",
AzureModelFactory.PROVIDER, "text-embedding-ada-002", AzureModelFactory.PROVIDER, "text-embedding-ada-002",

View File

@@ -5,6 +5,7 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel; import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel; import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import dev.langchain4j.provider.EmbeddingModelConstant;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -17,10 +18,6 @@ import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
@Configuration @Configuration
@EnableConfigurationProperties(Properties.class) @EnableConfigurationProperties(Properties.class)
public class InMemoryAutoConfig { public class InMemoryAutoConfig {
public static final String BGE_SMALL_ZH = "bge-small-zh";
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
@Bean @Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path") @ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
EmbeddingStoreFactory inMemoryChatModel(Properties properties) { EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
@@ -37,10 +34,10 @@ public class InMemoryAutoConfig {
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath); return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
} }
String modelName = embeddingModelProperties.getModelName(); String modelName = embeddingModelProperties.getModelName();
if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) { if (EmbeddingModelConstant.BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
return new BgeSmallZhEmbeddingModel(); return new BgeSmallZhEmbeddingModel();
} }
if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) { if (EmbeddingModelConstant.ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
return new AllMiniLmL6V2QuantizedEmbeddingModel(); return new AllMiniLmL6V2QuantizedEmbeddingModel();
} }
return new BgeSmallZhEmbeddingModel(); return new BgeSmallZhEmbeddingModel();

View File

@@ -8,6 +8,7 @@ import java.net.URL;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.Objects;
/** /**
* An embedding model that runs within your Java application's process. * An embedding model that runs within your Java application's process.
@@ -18,13 +19,14 @@ import java.nio.file.Paths;
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel. * Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
*/ */
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel { public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
private static volatile OnnxBertBiEncoder cachedModel;
private static volatile String cachedModelPath;
private static volatile String cachedVocabularyPath;
private static OnnxBertBiEncoder model = null;
/**
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
*/
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) { public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
synchronized (S2OnnxEmbeddingModel.class) {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt"); URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
if (StringUtils.isNotBlank(vocabularyPath)) { if (StringUtils.isNotBlank(vocabularyPath)) {
try { try {
@@ -33,21 +35,26 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
if (model == null) { cachedModel = loadFromFileSystem(Paths.get(pathToModel), resource);
model = loadFromFileSystem(Paths.get(pathToModel), resource); cachedModelPath = pathToModel;
cachedVocabularyPath = vocabularyPath;
}
}
} }
} }
/**
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
*/
public S2OnnxEmbeddingModel(String pathToModel) { public S2OnnxEmbeddingModel(String pathToModel) {
this(pathToModel, null); this(pathToModel, null);
} }
@Override @Override
protected OnnxBertBiEncoder model() { protected OnnxBertBiEncoder model() {
return model; return cachedModel;
}
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
} }
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) { static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {

View File

@@ -0,0 +1,16 @@
package dev.langchain4j.provider;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.springframework.stereotype.Service;
/**
 * Central holder for the names of the built-in in-memory embedding models and
 * for one shared instance of each of those models.
 *
 * <p>The model instances are created once, in the static initializers of this
 * class, so code that references them reuses the same objects instead of
 * constructing a new model each time.
 *
 * <p>NOTE(review): this class only exposes static members; confirm whether the
 * {@code @Service} registration is actually needed by any injection point.
 */
@Service
public class EmbeddingModelConstant {
/** Configuration name associated with {@link BgeSmallZhEmbeddingModel}. */
public static final String BGE_SMALL_ZH = "bge-small-zh";
/** Configuration name associated with {@link AllMiniLmL6V2QuantizedEmbeddingModel}. */
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
/** Shared {@link BgeSmallZhEmbeddingModel} instance, constructed when this class is initialized. */
public static final EmbeddingModel BGE_SMALL_ZH_MODEL = new BgeSmallZhEmbeddingModel();
/** Shared {@link AllMiniLmL6V2QuantizedEmbeddingModel} instance, constructed when this class is initialized. */
public static final EmbeddingModel ALL_MINI_LM_L6_V2_MODEL = new AllMiniLmL6V2QuantizedEmbeddingModel();
}

View File

@@ -3,17 +3,12 @@ package dev.langchain4j.provider;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel; import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.ALL_MINILM_L6_V2;
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
@Service @Service
public class InMemoryModelFactory implements ModelFactory, InitializingBean { public class InMemoryModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "IN_MEMORY"; public static final String PROVIDER = "IN_MEMORY";
@@ -31,13 +26,13 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath); return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
} }
String modelName = embeddingModel.getModelName(); String modelName = embeddingModel.getModelName();
if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) { if (EmbeddingModelConstant.BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
return new BgeSmallZhEmbeddingModel(); return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL;
} }
if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) { if (EmbeddingModelConstant.ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
return new AllMiniLmL6V2QuantizedEmbeddingModel(); return EmbeddingModelConstant.ALL_MINI_LM_L6_V2_MODEL;
} }
return new BgeSmallZhEmbeddingModel(); return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL;
} }
@Override @Override

View File

@@ -150,8 +150,8 @@ public class S2SemanticLayerService implements SemanticLayerService {
//2.query from cache //2.query from cache
String cacheKey = queryCache.getCacheKey(queryReq); String cacheKey = queryCache.getCacheKey(queryReq);
log.debug("cacheKey:{}", cacheKey);
Object query = queryCache.query(queryReq, cacheKey); Object query = queryCache.query(queryReq, cacheKey);
log.info("cacheKey:{},query:{}", cacheKey, query);
if (Objects.nonNull(query)) { if (Objects.nonNull(query)) {
SemanticQueryResp queryResp = (SemanticQueryResp) query; SemanticQueryResp queryResp = (SemanticQueryResp) query;
queryResp.setUseCache(true); queryResp.setUseCache(true);