diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java index ff3695052..c40cfb1e6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -4,9 +4,9 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.Parameter; -import dev.langchain4j.inmemory.spring.InMemoryAutoConfig; import dev.langchain4j.provider.AzureModelFactory; import dev.langchain4j.provider.DashscopeModelFactory; +import dev.langchain4j.provider.EmbeddingModelConstant; import dev.langchain4j.provider.InMemoryModelFactory; import dev.langchain4j.provider.OllamaModelFactory; import dev.langchain4j.provider.OpenAiModelFactory; @@ -76,7 +76,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_MODEL_NAME = - new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH, + new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH, "ModelName", "", "string", "向量模型配置", null, getDependency(EMBEDDING_MODEL_PROVIDER.getName(), @@ -90,7 +90,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig { ZhipuModelFactory.PROVIDER ), ImmutableMap.of( - InMemoryModelFactory.PROVIDER, InMemoryAutoConfig.BGE_SMALL_ZH, + InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH, OpenAiModelFactory.PROVIDER, "text-embedding-ada-002", OllamaModelFactory.PROVIDER, "all-minilm", AzureModelFactory.PROVIDER, "text-embedding-ada-002", diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java index ab79ea5d7..03d1c0ebc 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryAutoConfig.java @@ -5,6 +5,7 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel; +import dev.langchain4j.provider.EmbeddingModelConstant; import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import org.apache.commons.lang3.StringUtils; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -17,10 +18,6 @@ import static dev.langchain4j.inmemory.spring.Properties.PREFIX; @Configuration @EnableConfigurationProperties(Properties.class) public class InMemoryAutoConfig { - - public static final String BGE_SMALL_ZH = "bge-small-zh"; - public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q"; - @Bean @ConditionalOnProperty(PREFIX + ".embedding-store.persist-path") EmbeddingStoreFactory inMemoryChatModel(Properties properties) { @@ -37,10 +34,10 @@ public class InMemoryAutoConfig { return new S2OnnxEmbeddingModel(modelPath, vocabularyPath); } String modelName = embeddingModelProperties.getModelName(); - if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) { + if (EmbeddingModelConstant.BGE_SMALL_ZH.equalsIgnoreCase(modelName)) { return new BgeSmallZhEmbeddingModel(); } - if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) { + if (EmbeddingModelConstant.ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) { return new AllMiniLmL6V2QuantizedEmbeddingModel(); } return new BgeSmallZhEmbeddingModel(); diff --git a/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java b/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java index 92f7374ce..87bec18ca 100644 --- a/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java +++ b/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java @@ -8,6 +8,7 @@ import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Objects; /** * An embedding model that runs within your Java application's process. @@ -18,36 +19,42 @@ import java.nio.file.Paths; * Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel. */ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel { + private static volatile OnnxBertBiEncoder cachedModel; + private static volatile String cachedModelPath; + private static volatile String cachedVocabularyPath; - private static OnnxBertBiEncoder model = null; - - /** - * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx"). - */ public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) { - URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt"); - if (StringUtils.isNotBlank(vocabularyPath)) { - try { - resource = Paths.get(vocabularyPath).toUri().toURL(); - } catch (MalformedURLException e) { - throw new RuntimeException(e); + if (shouldReloadModel(pathToModel, vocabularyPath)) { + synchronized (S2OnnxEmbeddingModel.class) { + if (shouldReloadModel(pathToModel, vocabularyPath)) { + URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt"); + if (StringUtils.isNotBlank(vocabularyPath)) { + try { + resource = Paths.get(vocabularyPath).toUri().toURL(); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } + } + cachedModel = loadFromFileSystem(Paths.get(pathToModel), resource); + cachedModelPath = pathToModel; + cachedVocabularyPath = vocabularyPath; + } } } - if (model == null) { - model = loadFromFileSystem(Paths.get(pathToModel), resource); - } } - /** - * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx"). - */ public S2OnnxEmbeddingModel(String pathToModel) { this(pathToModel, null); } @Override protected OnnxBertBiEncoder model() { - return model; + return cachedModel; + } + + private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) { + return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel) + || !Objects.equals(cachedVocabularyPath, vocabularyPath); } static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) { @@ -61,4 +68,4 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel { throw new RuntimeException(e); } } -} +} \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/provider/EmbeddingModelConstant.java b/common/src/main/java/dev/langchain4j/provider/EmbeddingModelConstant.java new file mode 100644 index 000000000..b9b3eda94 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/provider/EmbeddingModelConstant.java @@ -0,0 +1,16 @@ +package dev.langchain4j.provider; + +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import org.springframework.stereotype.Service; + +@Service +public class EmbeddingModelConstant { + + public static final String BGE_SMALL_ZH = "bge-small-zh"; + public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q"; + public static final EmbeddingModel BGE_SMALL_ZH_MODEL = new BgeSmallZhEmbeddingModel(); + public static final EmbeddingModel ALL_MINI_LM_L6_V2_MODEL = new AllMiniLmL6V2QuantizedEmbeddingModel(); + +} diff --git a/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java b/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java index e532cff40..e8c074bcb 100644 --- a/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java @@ -3,17 +3,12 @@ package dev.langchain4j.provider; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; -import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.InitializingBean; import org.springframework.stereotype.Service; -import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.ALL_MINILM_L6_V2; -import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH; - @Service public class InMemoryModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "IN_MEMORY"; @@ -31,13 +26,13 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean { return new S2OnnxEmbeddingModel(modelPath, vocabularyPath); } String modelName = embeddingModel.getModelName(); - if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) { - return new BgeSmallZhEmbeddingModel(); + if (EmbeddingModelConstant.BGE_SMALL_ZH.equalsIgnoreCase(modelName)) { + return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL; } - if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) { - return new AllMiniLmL6V2QuantizedEmbeddingModel(); + if (EmbeddingModelConstant.ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) { + return EmbeddingModelConstant.ALL_MINI_LM_L6_V2_MODEL; } - return new BgeSmallZhEmbeddingModel(); + return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL; } @Override diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java index 031d8cdf8..e1082b934 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java @@ -150,8 +150,8 @@ public class S2SemanticLayerService implements SemanticLayerService { //2.query from cache String cacheKey = queryCache.getCacheKey(queryReq); - log.debug("cacheKey:{}", cacheKey); Object query = queryCache.query(queryReq, cacheKey); + log.info("cacheKey:{},query:{}", cacheKey, query); if (Objects.nonNull(query)) { SemanticQueryResp queryResp = (SemanticQueryResp) query; queryResp.setUseCache(true); @@ -495,7 +495,7 @@ public class S2SemanticLayerService implements SemanticLayerService { } private SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo, - DataSetSchema dataSetSchema, User user) { + DataSetSchema dataSetSchema, User user) { SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); semanticParseInfo.setDataSet(dataSetSchema.getDataSet()); semanticParseInfo.setQueryType(QueryType.DETAIL);