mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
(improvement)(chat) Reduce frequent loading of embedding models to improve loading performance. (#1478)
This commit is contained in:
@@ -4,9 +4,9 @@ import com.google.common.collect.ImmutableMap;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
import dev.langchain4j.inmemory.spring.InMemoryAutoConfig;
|
|
||||||
import dev.langchain4j.provider.AzureModelFactory;
|
import dev.langchain4j.provider.AzureModelFactory;
|
||||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||||
|
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||||
import dev.langchain4j.provider.OllamaModelFactory;
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
@@ -76,7 +76,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MODEL_NAME =
|
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||||
new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH,
|
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
"ModelName", "",
|
"ModelName", "",
|
||||||
"string", "向量模型配置", null,
|
"string", "向量模型配置", null,
|
||||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
@@ -90,7 +90,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
ZhipuModelFactory.PROVIDER
|
ZhipuModelFactory.PROVIDER
|
||||||
),
|
),
|
||||||
ImmutableMap.of(
|
ImmutableMap.of(
|
||||||
InMemoryModelFactory.PROVIDER, InMemoryAutoConfig.BGE_SMALL_ZH,
|
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
|
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||||
OllamaModelFactory.PROVIDER, "all-minilm",
|
OllamaModelFactory.PROVIDER, "all-minilm",
|
||||||
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
|
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
|||||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||||
|
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||||
@@ -17,10 +18,6 @@ import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
|
|||||||
@Configuration
|
@Configuration
|
||||||
@EnableConfigurationProperties(Properties.class)
|
@EnableConfigurationProperties(Properties.class)
|
||||||
public class InMemoryAutoConfig {
|
public class InMemoryAutoConfig {
|
||||||
|
|
||||||
public static final String BGE_SMALL_ZH = "bge-small-zh";
|
|
||||||
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
|
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
|
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
|
||||||
EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
|
EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
|
||||||
@@ -37,10 +34,10 @@ public class InMemoryAutoConfig {
|
|||||||
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
|
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
|
||||||
}
|
}
|
||||||
String modelName = embeddingModelProperties.getModelName();
|
String modelName = embeddingModelProperties.getModelName();
|
||||||
if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
|
if (EmbeddingModelConstant.BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
|
||||||
return new BgeSmallZhEmbeddingModel();
|
return new BgeSmallZhEmbeddingModel();
|
||||||
}
|
}
|
||||||
if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
|
if (EmbeddingModelConstant.ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
|
||||||
return new AllMiniLmL6V2QuantizedEmbeddingModel();
|
return new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||||
}
|
}
|
||||||
return new BgeSmallZhEmbeddingModel();
|
return new BgeSmallZhEmbeddingModel();
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import java.net.URL;
|
|||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An embedding model that runs within your Java application's process.
|
* An embedding model that runs within your Java application's process.
|
||||||
@@ -18,36 +19,42 @@ import java.nio.file.Paths;
|
|||||||
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
|
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
|
||||||
*/
|
*/
|
||||||
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||||
|
private static volatile OnnxBertBiEncoder cachedModel;
|
||||||
|
private static volatile String cachedModelPath;
|
||||||
|
private static volatile String cachedVocabularyPath;
|
||||||
|
|
||||||
private static OnnxBertBiEncoder model = null;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
|
|
||||||
*/
|
|
||||||
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
||||||
URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
|
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||||
if (StringUtils.isNotBlank(vocabularyPath)) {
|
synchronized (S2OnnxEmbeddingModel.class) {
|
||||||
try {
|
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||||
resource = Paths.get(vocabularyPath).toUri().toURL();
|
URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
|
||||||
} catch (MalformedURLException e) {
|
if (StringUtils.isNotBlank(vocabularyPath)) {
|
||||||
throw new RuntimeException(e);
|
try {
|
||||||
|
resource = Paths.get(vocabularyPath).toUri().toURL();
|
||||||
|
} catch (MalformedURLException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cachedModel = loadFromFileSystem(Paths.get(pathToModel), resource);
|
||||||
|
cachedModelPath = pathToModel;
|
||||||
|
cachedVocabularyPath = vocabularyPath;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (model == null) {
|
|
||||||
model = loadFromFileSystem(Paths.get(pathToModel), resource);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
|
|
||||||
*/
|
|
||||||
public S2OnnxEmbeddingModel(String pathToModel) {
|
public S2OnnxEmbeddingModel(String pathToModel) {
|
||||||
this(pathToModel, null);
|
this(pathToModel, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected OnnxBertBiEncoder model() {
|
protected OnnxBertBiEncoder model() {
|
||||||
return model;
|
return cachedModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
|
||||||
|
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|
||||||
|
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
||||||
@@ -61,4 +68,4 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
|||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class EmbeddingModelConstant {
|
||||||
|
|
||||||
|
public static final String BGE_SMALL_ZH = "bge-small-zh";
|
||||||
|
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
|
||||||
|
public static final EmbeddingModel BGE_SMALL_ZH_MODEL = new BgeSmallZhEmbeddingModel();
|
||||||
|
public static final EmbeddingModel ALL_MINI_LM_L6_V2_MODEL = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||||
|
|
||||||
|
}
|
||||||
@@ -3,17 +3,12 @@ package dev.langchain4j.provider;
|
|||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
import org.springframework.beans.factory.InitializingBean;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.ALL_MINILM_L6_V2;
|
|
||||||
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||||
public static final String PROVIDER = "IN_MEMORY";
|
public static final String PROVIDER = "IN_MEMORY";
|
||||||
@@ -31,13 +26,13 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
|||||||
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
|
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
|
||||||
}
|
}
|
||||||
String modelName = embeddingModel.getModelName();
|
String modelName = embeddingModel.getModelName();
|
||||||
if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
|
if (EmbeddingModelConstant.BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
|
||||||
return new BgeSmallZhEmbeddingModel();
|
return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL;
|
||||||
}
|
}
|
||||||
if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
|
if (EmbeddingModelConstant.ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
|
||||||
return new AllMiniLmL6V2QuantizedEmbeddingModel();
|
return EmbeddingModelConstant.ALL_MINI_LM_L6_V2_MODEL;
|
||||||
}
|
}
|
||||||
return new BgeSmallZhEmbeddingModel();
|
return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -150,8 +150,8 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
|||||||
//2.query from cache
|
//2.query from cache
|
||||||
|
|
||||||
String cacheKey = queryCache.getCacheKey(queryReq);
|
String cacheKey = queryCache.getCacheKey(queryReq);
|
||||||
log.debug("cacheKey:{}", cacheKey);
|
|
||||||
Object query = queryCache.query(queryReq, cacheKey);
|
Object query = queryCache.query(queryReq, cacheKey);
|
||||||
|
log.info("cacheKey:{},query:{}", cacheKey, query);
|
||||||
if (Objects.nonNull(query)) {
|
if (Objects.nonNull(query)) {
|
||||||
SemanticQueryResp queryResp = (SemanticQueryResp) query;
|
SemanticQueryResp queryResp = (SemanticQueryResp) query;
|
||||||
queryResp.setUseCache(true);
|
queryResp.setUseCache(true);
|
||||||
@@ -495,7 +495,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo,
|
private SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo,
|
||||||
DataSetSchema dataSetSchema, User user) {
|
DataSetSchema dataSetSchema, User user) {
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
|
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
|
||||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||||
|
|||||||
Reference in New Issue
Block a user