diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java index 0d6157f92..66465c5b6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java @@ -14,6 +14,8 @@ import java.util.List; @Slf4j public class ChatModelParameterConfig extends ParameterConfig { + + public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER, "接口协议", "", @@ -21,17 +23,17 @@ public class ChatModelParameterConfig extends ParameterConfig { Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER)); public static final Parameter CHAT_MODEL_BASE_URL = - new Parameter("s2.chat.model.base.url", "", + new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1", "BaseUrl", "", "string", "对话模型配置"); public static final Parameter CHAT_MODEL_API_KEY = - new Parameter("s2.chat.model.api.key", "", + new Parameter("s2.chat.model.api.key", "demo", "ApiKey", "", "string", "对话模型配置"); public static final Parameter CHAT_MODEL_NAME = - new Parameter("s2.chat.model.name", "", + new Parameter("s2.chat.model.name", "gpt-3.5-turbo", "ModelName", "", "string", "对话模型配置"); diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java index cecfedb7e..ffd9f8f52 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.common.config; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.Parameter; +import dev.langchain4j.inmemory.spring.InMemoryAutoConfig; import dev.langchain4j.provider.AzureModelFactory; import dev.langchain4j.provider.DashscopeModelFactory; import dev.langchain4j.provider.InMemoryModelFactory; @@ -20,7 +21,7 @@ import java.util.List; public class EmbeddingModelParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_MODEL_PROVIDER = - new Parameter("s2.embedding.model.provider", "", + new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "", "string", "向量模型配置", Lists.newArrayList(InMemoryModelFactory.PROVIDER, @@ -43,9 +44,10 @@ public class EmbeddingModelParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_MODEL_NAME = - new Parameter("s2.embedding.model.name", "", + new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH, "ModelName", "", - "string", "向量模型配置"); + "string", "向量模型配置", + Lists.newArrayList(InMemoryAutoConfig.BGE_SMALL_ZH, InMemoryAutoConfig.ALL_MINILM_L6_V2)); public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path", "", diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java index 2861061c0..a14662bc1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.common.config; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.Parameter; +import dev.langchain4j.store.embedding.EmbeddingStoreType; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -12,9 +13,12 @@ import java.util.List; @Slf4j public class EmbeddingStoreParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_STORE_PROVIDER = - new Parameter("s2.embedding.store.provider", "", + new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型", "", - "string", "向量库配置"); + "string", "向量库配置", + Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(), + EmbeddingStoreType.MILVUS.name(), + EmbeddingStoreType.CHROMA.name())); public static final Parameter EMBEDDING_STORE_BASE_URL = new Parameter("s2.embedding.store.base.url", "", diff --git a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java index a2e7aeed5..11d4e6940 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java @@ -1,5 +1,6 @@ package dev.langchain4j.provider; +import com.tencent.supersonic.common.config.ChatModelParameterConfig; import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; @@ -18,11 +19,16 @@ public class ModelProvider { factories.put(provider, modelFactory); } + public static ChatLanguageModel getChatModel() { + return getChatModel(null); + } + public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) { - if (modelConfig == null - || StringUtils.isBlank(modelConfig.getProvider()) + if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider()) || StringUtils.isBlank(modelConfig.getBaseUrl())) { - return ContextUtils.getBean(ChatLanguageModel.class); + ChatModelParameterConfig parameterConfig = ContextUtils.getBean( + ChatModelParameterConfig.class); + modelConfig = parameterConfig.convert(); } ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase()); if (modelFactory != null) { @@ -33,15 +39,14 @@ public class ModelProvider { } public static EmbeddingModel getEmbeddingModel() { - EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean( - EmbeddingModelParameterConfig.class); - EmbeddingModelConfig embeddingModelConfig = parameterConfig.convert(); - return getEmbeddingModel(embeddingModelConfig); + return getEmbeddingModel(null); } public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) { if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) { - return ContextUtils.getBean(EmbeddingModel.class); + EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean( + EmbeddingModelParameterConfig.class); + embeddingModel = parameterConfig.convert(); } ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase()); if (modelFactory != null) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java index 61f2daa09..6cf0c31b1 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java @@ -5,8 +5,8 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; +import dev.langchain4j.provider.ModelProvider; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; @@ -14,11 +14,9 @@ import org.springframework.stereotype.Component; @Slf4j public class AliasGenerateHelper { - @Autowired - private ChatLanguageModel chatLanguageModel; - public String getChatCompletion(String message) { SystemMessage from = SystemMessage.from(message); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(); Response response = chatLanguageModel.generate(from); log.info("message:{}\n response:{}", message, response); return response.content().text(); diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 4ec5e5d06..21fe6b2e2 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -80,15 +80,4 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ - com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor - - -### common SPIs - -org.springframework.boot.autoconfigure.EnableAutoConfiguration=\ - dev.langchain4j.spring.LangChain4jAutoConfig,\ - dev.langchain4j.openai.spring.AutoConfig,\ - dev.langchain4j.ollama.spring.AutoConfig,\ - dev.langchain4j.azure.openai.spring.AutoConfig,\ - dev.langchain4j.azure.aisearch.spring.AutoConfig,\ - dev.langchain4j.anthropic.spring.AutoConfig \ No newline at end of file + com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index cf8458502..9e1d79013 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -11,7 +11,4 @@ spring: h2: console: path: /h2-console/semantic - enabled: true - config: - import: - - classpath:langchain4j-local.yaml \ No newline at end of file + enabled: true \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/application-prd.yaml b/launchers/standalone/src/main/resources/application-prd.yaml index 87b989dcb..a78d207ea 100644 --- a/launchers/standalone/src/main/resources/application-prd.yaml +++ b/launchers/standalone/src/main/resources/application-prd.yaml @@ -3,7 +3,4 @@ spring: url: jdbc:mysql://${DB_HOST}:${DB_PORT:3306}/${DB_NAME}?useUnicode=true&characterEncoding=UTF-8&useSSL=false&allowMultiQueries=true&allowPublicKeyRetrieval=true username: ${DB_USERNAME} password: ${DB_PASSWORD} - driver-class-name: com.mysql.jdbc.Driver - config: - import: - - classpath:langchain4j-prd.yaml \ No newline at end of file + driver-class-name: com.mysql.jdbc.Driver \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/langchain4j-local.yaml b/launchers/standalone/src/main/resources/langchain4j-local.yaml deleted file mode 100644 index f07915f81..000000000 --- a/launchers/standalone/src/main/resources/langchain4j-local.yaml +++ /dev/null @@ -1,42 +0,0 @@ -langchain4j: - # Replace `open_ai` with ollama/zhipu/azure/dashscope as needed. - # Note: - # 1. `open_ai` is commonly used to connect to cloud-based models; - # 2. `ollama` is commonly used to connect to local models. - open-ai: - chat-model: - # It is recommended to replace with your API key in production. - # Note: The default API key `demo` is provided by langchain4j community - # which limits 1000 tokens per request. - base-url: ${OPENAI_API_BASE:https://api.openai.com/v1} - api-key: ${OPENAI_API_KEY:demo} - model-name: ${OPENAI_MODEL_NAME:gpt-3.5-turbo} - temperature: ${OPENAI_TEMPERATURE:0.0} - timeout: ${OPENAI_TIMEOUT:PT60S} - - # embedding-model: - # base-url: https://api.openai.com/v1 - # api-key: demo - # model-name: text-embedding-3-small - # timeout: PT60S - - in-memory: - embedding-model: - model-name: bge-small-zh - - embedding-store: - persist-path: /tmp - -# chroma: -# embedding-store: -# baseUrl: http://0.0.0.0:8000 -# timeout: 120s - -# milvus: -# embedding-store: -# host: localhost -# port: 2379 -# uri: http://0.0.0.0:19530 -# token: demo -# dimension: 512 -# timeout: 120s \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/langchain4j-prd.yaml b/launchers/standalone/src/main/resources/langchain4j-prd.yaml deleted file mode 100644 index 89f329e0f..000000000 --- a/launchers/standalone/src/main/resources/langchain4j-prd.yaml +++ /dev/null @@ -1,42 +0,0 @@ -langchain4j: - # Replace `open_ai` with ollama/zhipu/azure/dashscope as needed. - # Note: - # 1. `open_ai` is commonly used to connect to cloud-based models; - # 2. `ollama` is commonly used to connect to local models. - open-ai: - chat-model: - # It is recommended to replace with your API key in production. - # Note: The default API key `demo` is provided by langchain4j community - # which limits 1000 tokens per request. - base-url: ${OPENAI_API_BASE:https://api.openai.com/v1} - api-key: ${OPENAI_API_KEY:demo} - model-name: ${OPENAI_MODEL_NAME:gpt-3.5-turbo} - temperature: ${OPENAI_TEMPERATURE:0.0} - timeout: ${OPENAI_TIMEOUT:PT60S} - - # embedding-model: - # base-url: https://api.openai.com/v1 - # api-key: demo - # model-name: text-embedding-3-small - # timeout: PT60S - - in-memory: - embedding-model: - model-name: bge-small-zh - -# embedding-store: -# persist-path: /tmp - - chroma: - embedding-store: - baseUrl: http://${CHROMA_HOST}:8000 - timeout: 120s - -# milvus: -# embedding-store: -# host: localhost -# port: 2379 -# uri: http://0.0.0.0:19530 -# token: demo -# dimension: 512 -# timeout: 120s \ No newline at end of file diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index cf8458502..9e1d79013 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -11,7 +11,4 @@ spring: h2: console: path: /h2-console/semantic - enabled: true - config: - import: - - classpath:langchain4j-local.yaml \ No newline at end of file + enabled: true \ No newline at end of file diff --git a/launchers/standalone/src/test/resources/langchain4j-local.yaml b/launchers/standalone/src/test/resources/langchain4j-local.yaml deleted file mode 100644 index f07915f81..000000000 --- a/launchers/standalone/src/test/resources/langchain4j-local.yaml +++ /dev/null @@ -1,42 +0,0 @@ -langchain4j: - # Replace `open_ai` with ollama/zhipu/azure/dashscope as needed. - # Note: - # 1. `open_ai` is commonly used to connect to cloud-based models; - # 2. `ollama` is commonly used to connect to local models. - open-ai: - chat-model: - # It is recommended to replace with your API key in production. - # Note: The default API key `demo` is provided by langchain4j community - # which limits 1000 tokens per request. - base-url: ${OPENAI_API_BASE:https://api.openai.com/v1} - api-key: ${OPENAI_API_KEY:demo} - model-name: ${OPENAI_MODEL_NAME:gpt-3.5-turbo} - temperature: ${OPENAI_TEMPERATURE:0.0} - timeout: ${OPENAI_TIMEOUT:PT60S} - - # embedding-model: - # base-url: https://api.openai.com/v1 - # api-key: demo - # model-name: text-embedding-3-small - # timeout: PT60S - - in-memory: - embedding-model: - model-name: bge-small-zh - - embedding-store: - persist-path: /tmp - -# chroma: -# embedding-store: -# baseUrl: http://0.0.0.0:8000 -# timeout: 120s - -# milvus: -# embedding-store: -# host: localhost -# port: 2379 -# uri: http://0.0.0.0:19530 -# token: demo -# dimension: 512 -# timeout: 120s \ No newline at end of file