diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java index 465470379..fb8bc568a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java @@ -4,8 +4,15 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Parameter; +import dev.ai4j.openai4j.chat.ChatCompletionModel; +import dev.langchain4j.model.dashscope.QwenModelName; +import dev.langchain4j.provider.AzureModelFactory; +import dev.langchain4j.provider.DashscopeModelFactory; +import dev.langchain4j.provider.LocalAiModelFactory; import dev.langchain4j.provider.OllamaModelFactory; import dev.langchain4j.provider.OpenAiModelFactory; +import dev.langchain4j.provider.QianfanModelFactory; +import dev.langchain4j.provider.ZhipuModelFactory; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -18,49 +25,36 @@ public class ChatModelParameterConfig extends ParameterConfig { public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER, - "接口协议", "", - "list", "对话模型配置", - getCandidateValues()); - + "接口协议", "", "list", + "对话模型配置", getCandidateValues()); public static final Parameter CHAT_MODEL_BASE_URL = - new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1", + new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL, "BaseUrl", "", "string", - "对话模型配置", null, - getDependency(CHAT_MODEL_PROVIDER.getName(), - getCandidateValues(), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1", - OllamaModelFactory.PROVIDER, "http://localhost:11434") - ) - ); - + "对话模型配置", null, 
getBaseUrlDependency()); + public static final Parameter CHAT_MODEL_ENDPOINT = + new Parameter("s2.chat.model.endpoint", "llama_2_70b", + "Endpoint", "", "string", + "对话模型配置", null, getEndpointDependency()); public static final Parameter CHAT_MODEL_API_KEY = - new Parameter("s2.chat.model.api.key", "demo", - "ApiKey", "", - "string", "对话模型配置", null, - getDependency(CHAT_MODEL_PROVIDER.getName(), - Lists.newArrayList(OpenAiModelFactory.PROVIDER), - ImmutableMap.of(OpenAiModelFactory.PROVIDER, "demo")) + new Parameter("s2.chat.model.api.key", DEMO, + "ApiKey", "", "password", + "对话模型配置", null, getApiKeyDependency() ); + public static final Parameter CHAT_MODEL_SECRET_KEY = + new Parameter("s2.chat.model.secretKey", "demo", + "SecretKey", "", "password", + "对话模型配置", null, getSecretKeyDependency()); public static final Parameter CHAT_MODEL_NAME = new Parameter("s2.chat.model.name", "gpt-3.5-turbo", - "ModelName", "", - "string", "对话模型配置", null, - getDependency(CHAT_MODEL_PROVIDER.getName(), - getCandidateValues(), - ImmutableMap.of(OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo", - OllamaModelFactory.PROVIDER, "qwen:0.5b") - )); + "ModelName", "", "string", + "对话模型配置", null, getModelNameDependency()); public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter("s2.chat.model.temperature", "0.0", "Temperature", "", - "slider", "对话模型配置", null, - getDependency(CHAT_MODEL_PROVIDER.getName(), - getCandidateValues(), - ImmutableMap.of(OpenAiModelFactory.PROVIDER, "0.0", OllamaModelFactory.PROVIDER, "0.0"))); + "slider", "对话模型配置"); public static final Parameter CHAT_MODEL_TIMEOUT = new Parameter("s2.chat.model.timeout", "60", @@ -70,8 +64,9 @@ public class ChatModelParameterConfig extends ParameterConfig { @Override public List getSysParameters() { return Lists.newArrayList( - CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY, - CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT + CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT, + 
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, + CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT ); } @@ -82,6 +77,8 @@ public class ChatModelParameterConfig extends ParameterConfig { String chatModelName = getParameterValue(CHAT_MODEL_NAME); String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE); String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT); + String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT); + String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY); return ChatModelConfig.builder() .provider(chatModelProvider) @@ -90,10 +87,94 @@ public class ChatModelParameterConfig extends ParameterConfig { .modelName(chatModelName) .temperature(Double.valueOf(chatModelTemperature)) .timeOut(Long.valueOf(chatModelTimeout)) + .endpoint(endpoint) + .secretKey(secretKey) .build(); } - private static ArrayList getCandidateValues() { - return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER); + private static List getCandidateValues() { + List candidateValues = getBaseUrlCandidateValues(); + candidateValues.add(AzureModelFactory.PROVIDER); + return candidateValues; + } + + private static ArrayList getBaseUrlCandidateValues() { + return Lists.newArrayList( + OpenAiModelFactory.PROVIDER, + OllamaModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, + ZhipuModelFactory.PROVIDER, + LocalAiModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER); + } + + private static List getBaseUrlDependency() { + return getDependency(CHAT_MODEL_PROVIDER.getName(), + getBaseUrlCandidateValues(), + ImmutableMap.of( + OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL, + OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL, + QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL, + ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL, + LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL, + DashscopeModelFactory.PROVIDER, 
DashscopeModelFactory.DEFAULT_BASE_URL) + ); + } + + private static List getApiKeyDependency() { + return getDependency(CHAT_MODEL_PROVIDER.getName(), + Lists.newArrayList( + OpenAiModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, + ZhipuModelFactory.PROVIDER, + LocalAiModelFactory.PROVIDER, + AzureModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER + ), + ImmutableMap.of( + OpenAiModelFactory.PROVIDER, DEMO, + QianfanModelFactory.PROVIDER, DEMO, + ZhipuModelFactory.PROVIDER, DEMO, + LocalAiModelFactory.PROVIDER, DEMO, + AzureModelFactory.PROVIDER, DEMO, + DashscopeModelFactory.PROVIDER, DEMO + )); + } + + private static List getModelNameDependency() { + return getDependency(CHAT_MODEL_PROVIDER.getName(), + getCandidateValues(), + ImmutableMap.of( + OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo", + OllamaModelFactory.PROVIDER, "qwen:0.5b", + QianfanModelFactory.PROVIDER, "Llama-2-70b-chat", + ZhipuModelFactory.PROVIDER, ChatCompletionModel.GPT_4.toString(), + LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j", + AzureModelFactory.PROVIDER, "gpt-35-turbo", + DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS + ) + ); + } + + private static List getEndpointDependency() { + return getDependency(CHAT_MODEL_PROVIDER.getName(), + Lists.newArrayList( + AzureModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER + ), + ImmutableMap.of( + AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, + QianfanModelFactory.PROVIDER, "llama_2_70b" + ) + ); + } + + private static List getSecretKeyDependency() { + return getDependency(CHAT_MODEL_PROVIDER.getName(), + Lists.newArrayList(QianfanModelFactory.PROVIDER), + ImmutableMap.of( + QianfanModelFactory.PROVIDER, DEMO + ) + ); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java index c40cfb1e6..6082ca6e8 100644 --- 
a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -21,109 +21,34 @@ import java.util.List; @Service("EmbeddingModelParameterConfig") @Slf4j public class EmbeddingModelParameterConfig extends ParameterConfig { - public static final Parameter EMBEDDING_MODEL_PROVIDER = new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, - "接口协议", "", - "list", "向量模型配置", - getCandidateValues()); + "接口协议", "", "list", + "向量模型配置", getCandidateValues()); public static final Parameter EMBEDDING_MODEL_BASE_URL = new Parameter("s2.embedding.model.base.url", "", - "BaseUrl", "", - "string", "向量模型配置", null, - getDependency(EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - OpenAiModelFactory.PROVIDER, - OllamaModelFactory.PROVIDER, - AzureModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER - ), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1", - OllamaModelFactory.PROVIDER, "http://localhost:11434", - AzureModelFactory.PROVIDER, "https://xxxx.openai.azure.com/", - DashscopeModelFactory.PROVIDER, "https://dashscope.aliyuncs.com/compatible-mode/v1", - QianfanModelFactory.PROVIDER, "https://aip.baidubce.com", - ZhipuModelFactory.PROVIDER, "https://open.bigmodel.cn/api/paas/v4/" - ) - ) + "BaseUrl", "", "string", + "向量模型配置", null, getBaseUrlDependency() ); public static final Parameter EMBEDDING_MODEL_API_KEY = new Parameter("s2.embedding.model.api.key", "", - "ApiKey", "", - "string", "向量模型配置", null, - getDependency(EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - OpenAiModelFactory.PROVIDER, - OllamaModelFactory.PROVIDER, - AzureModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER - ), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, "demo", - 
OllamaModelFactory.PROVIDER, "demo", - AzureModelFactory.PROVIDER, "demo", - DashscopeModelFactory.PROVIDER, "demo", - QianfanModelFactory.PROVIDER, "demo", - ZhipuModelFactory.PROVIDER, "demo" - ) - )); - + "ApiKey", "", "password", + "向量模型配置", null, getApiKeyDependency()); public static final Parameter EMBEDDING_MODEL_NAME = new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH, - "ModelName", "", - "string", "向量模型配置", null, - getDependency(EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - InMemoryModelFactory.PROVIDER, - OpenAiModelFactory.PROVIDER, - OllamaModelFactory.PROVIDER, - AzureModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER - ), - ImmutableMap.of( - InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH, - OpenAiModelFactory.PROVIDER, "text-embedding-ada-002", - OllamaModelFactory.PROVIDER, "all-minilm", - AzureModelFactory.PROVIDER, "text-embedding-ada-002", - DashscopeModelFactory.PROVIDER, "text-embedding-ada-002", - QianfanModelFactory.PROVIDER, "text-embedding-ada-002", - ZhipuModelFactory.PROVIDER, "text-embedding-ada-002" - ) - )); + "ModelName", "", "string", + "向量模型配置", null, getModelNameDependency()); public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path", "", - "模型路径", "", - "string", "向量模型配置", null, - getDependency(EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - InMemoryModelFactory.PROVIDER - ), - ImmutableMap.of( - InMemoryModelFactory.PROVIDER, "" - ) - )); - + "模型路径", "", "string", + "向量模型配置", null, getModelPathDependency()); public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH = new Parameter("s2.embedding.model.vocabulary.path", "", - "词汇表路径", "", - "string", "向量模型配置", null, - getDependency(EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - InMemoryModelFactory.PROVIDER - ), - ImmutableMap.of( - InMemoryModelFactory.PROVIDER, "" - ))); + "词汇表路径", "", 
"string", + "向量模型配置", null, getModelPathDependency()); @Override public List getSysParameters() { @@ -152,13 +77,80 @@ public class EmbeddingModelParameterConfig extends ParameterConfig { } private static ArrayList getCandidateValues() { - return Lists.newArrayList(InMemoryModelFactory.PROVIDER, + return Lists.newArrayList( + InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER); + ZhipuModelFactory.PROVIDER + ); } + private static List getBaseUrlDependency() { + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), + Lists.newArrayList(OpenAiModelFactory.PROVIDER, + OllamaModelFactory.PROVIDER, + AzureModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, + ZhipuModelFactory.PROVIDER), + ImmutableMap.of( + OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL, + OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL, + AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, + DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL, + QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL, + ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL + ) + ); + } + + private static List getApiKeyDependency() { + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), + Lists.newArrayList(OpenAiModelFactory.PROVIDER, + OllamaModelFactory.PROVIDER, + AzureModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, + ZhipuModelFactory.PROVIDER), + ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, + OllamaModelFactory.PROVIDER, DEMO, + AzureModelFactory.PROVIDER, DEMO, + DashscopeModelFactory.PROVIDER, DEMO, + QianfanModelFactory.PROVIDER, DEMO, + ZhipuModelFactory.PROVIDER, DEMO) + ); + } + + private static List getModelNameDependency() { + return 
getDependency(EMBEDDING_MODEL_PROVIDER.getName(), + Lists.newArrayList( + InMemoryModelFactory.PROVIDER, + OpenAiModelFactory.PROVIDER, + OllamaModelFactory.PROVIDER, + AzureModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, + ZhipuModelFactory.PROVIDER + ), + ImmutableMap.of( + InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH, + OpenAiModelFactory.PROVIDER, "text-embedding-ada-002", + OllamaModelFactory.PROVIDER, "all-minilm", + AzureModelFactory.PROVIDER, "text-embedding-ada-002", + DashscopeModelFactory.PROVIDER, "text-embedding-ada-002", + QianfanModelFactory.PROVIDER, "text-embedding-ada-002", + ZhipuModelFactory.PROVIDER, "text-embedding-ada-002" + ) + ); + } + + private static List getModelPathDependency() { + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), + Lists.newArrayList(InMemoryModelFactory.PROVIDER), + ImmutableMap.of(InMemoryModelFactory.PROVIDER, "") + ); + } } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java index 2eaf942d7..18bdb2913 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -9,6 +9,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; +import java.util.ArrayList; import java.util.List; @Service("EmbeddingStoreParameterConfig") @@ -16,50 +17,23 @@ import java.util.List; public class EmbeddingStoreParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), - "向量库类型", "", - "list", "向量库配置", - Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(), - 
EmbeddingStoreType.MILVUS.name(), - EmbeddingStoreType.CHROMA.name())); + "向量库类型", "", "list", + "向量库配置", getCandidateValues()); public static final Parameter EMBEDDING_STORE_BASE_URL = new Parameter("s2.embedding.store.base.url", "", - "BaseUrl", "", - "string", "向量库配置", null, - getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList( - EmbeddingStoreType.MILVUS.name(), - EmbeddingStoreType.CHROMA.name() - ), - ImmutableMap.of( - EmbeddingStoreType.MILVUS.name(), "http://localhost:19530", - EmbeddingStoreType.CHROMA.name(), "http://localhost:8000" - ) - )); + "BaseUrl", "", "string", + "向量库配置", null, getBaseUrlDependency()); public static final Parameter EMBEDDING_STORE_API_KEY = new Parameter("s2.embedding.store.api.key", "", - "ApiKey", "", - "string", "向量库配置", null, - getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList( - EmbeddingStoreType.MILVUS.name() - ), - ImmutableMap.of( - EmbeddingStoreType.MILVUS.name(), "demo" - ) - )); + "ApiKey", "", "password", + "向量库配置", null, getApiKeyDependency()); + public static final Parameter EMBEDDING_STORE_PERSIST_PATH = new Parameter("s2.embedding.store.persist.path", "/tmp", - "持久化路径", "", - "string", "向量库配置", null, - getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList( - EmbeddingStoreType.IN_MEMORY.name() - ), - ImmutableMap.of( - EmbeddingStoreType.IN_MEMORY.name(), "/tmp" - ))); + "持久化路径", "", "string", + "向量库配置", null, getPathDependency()); public static final Parameter EMBEDDING_STORE_TIMEOUT = new Parameter("s2.embedding.store.timeout", "60", @@ -68,16 +42,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_STORE_DIMENSION = new Parameter("s2.embedding.store.dimension", "", - "纬度", "", - "number", "向量库配置", null, - getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList( - EmbeddingStoreType.MILVUS.name() - ), - ImmutableMap.of( - EmbeddingStoreType.MILVUS.name(), "384" - ) - )); + "维度" /* display label: vector dimension (was "纬度", i.e. latitude) */, 
"", "number", + "向量库配置", null, getDimensionDependency()); @Override public List getSysParameters() { @@ -97,13 +63,50 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) { dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION)); } - return EmbeddingStoreConfig.builder() - .provider(provider) - .baseUrl(baseUrl) - .apiKey(apiKey) - .persistPath(persistPath) - .timeOut(Long.valueOf(timeOut)) - .dimension(dimension) - .build(); + return EmbeddingStoreConfig.builder().provider(provider) + .baseUrl(baseUrl).apiKey(apiKey).persistPath(persistPath) + .timeOut(Long.valueOf(timeOut)).dimension(dimension).build(); + } + + private static ArrayList getCandidateValues() { + return Lists.newArrayList( + EmbeddingStoreType.IN_MEMORY.name(), + EmbeddingStoreType.MILVUS.name(), + EmbeddingStoreType.CHROMA.name()); + } + + private static List getBaseUrlDependency() { + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), + Lists.newArrayList( + EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()), + ImmutableMap.of( + EmbeddingStoreType.MILVUS.name(), "http://localhost:19530", + EmbeddingStoreType.CHROMA.name(), "http://localhost:8000" + ) + ); + } + + private static List getApiKeyDependency() { + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), + Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), + ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO) + ); + } + + private static List getPathDependency() { + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), + Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()), + ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), "/tmp")); + } + + private static List getDimensionDependency() { + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), + Lists.newArrayList( + EmbeddingStoreType.MILVUS.name() + ), + ImmutableMap.of( + EmbeddingStoreType.MILVUS.name(), "384" + ) + ); } } \ No 
newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java index d7339ae9a..d1d8ab40c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java @@ -14,7 +14,7 @@ import java.util.Map; @Service public abstract class ParameterConfig { - + public static final String DEMO = "demo"; @Autowired private SystemConfigService sysConfigService; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java index 17711e0b5..e96478e29 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java @@ -20,6 +20,12 @@ public class ChatModelConfig implements Serializable { private String modelName; private Double temperature = 0.0d; private Long timeOut = 60L; + private String endpoint; + private String secretKey; + private Double topP; + private Integer maxRetries = 3; + private Boolean logRequests = false; + private Boolean logResponses = false; public String keyDecrypt() { return AESEncryptionUtil.aesDecryptECB(getApiKey()); diff --git a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java index 9f23d54c1..12f44cae8 100644 --- a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java @@ -14,15 +14,19 @@ import java.time.Duration; @Service public class AzureModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "AZURE"; + public static final String DEFAULT_BASE_URL = "https://xxxx.openai.azure.com/"; 
@Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder() - .endpoint(modelConfig.getBaseUrl()) + .endpoint(modelConfig.getEndpoint()) .apiKey(modelConfig.getApiKey()) .deploymentName(modelConfig.getModelName()) .temperature(modelConfig.getTemperature()) - .timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut())); + .maxRetries(modelConfig.getMaxRetries()) + .topP(modelConfig.getTopP()) + .timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut())) + .logRequestsAndResponses(/* null-safe: enable combined logging only when both flags are explicitly true */ Boolean.TRUE.equals(modelConfig.getLogRequests()) && Boolean.TRUE.equals(modelConfig.getLogResponses())); return builder.build(); } diff --git a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java index db391abbf..8460cd98a 100644 --- a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java @@ -12,6 +12,7 @@ import org.springframework.stereotype.Service; @Service public class DashscopeModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "DASHSCOPE"; + public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { @@ -21,6 +22,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean { .modelName(modelConfig.getModelName()) .temperature(modelConfig.getTemperature() == null ? 
0L : modelConfig.getTemperature().floatValue()) + .topP(modelConfig.getTopP()) .build(); } diff --git a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java index ea95287cd..15723b30d 100644 --- a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java @@ -14,7 +14,7 @@ import java.time.Duration; @Service public class LocalAiModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "LOCAL_AI"; - + public static final String DEFAULT_BASE_URL = "http://localhost:8080"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return LocalAiChatModel @@ -23,6 +23,10 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean { .modelName(modelConfig.getModelName()) .temperature(modelConfig.getTemperature()) .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) + .topP(modelConfig.getTopP()) + .logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()) + .maxRetries(modelConfig.getMaxRetries()) .build(); } @@ -31,6 +35,9 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean { return LocalAiEmbeddingModel.builder() .baseUrl(embeddingModel.getBaseUrl()) .modelName(embeddingModel.getModelName()) + .maxRetries(embeddingModel.getMaxRetries()) + .logRequests(embeddingModel.getLogRequests()) + .logResponses(embeddingModel.getLogResponses()) .build(); } diff --git a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java index 4eb24aa98..ab8c225d0 100644 --- a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java @@ -13,7 +13,9 @@ import java.time.Duration; @Service public class 
OllamaModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "OLLAMA"; + public static final String DEFAULT_BASE_URL = "http://localhost:11434"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { @@ -23,6 +25,10 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean { .modelName(modelConfig.getModelName()) .temperature(modelConfig.getTemperature()) .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) + .topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()) + .logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()) .build(); } diff --git a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java index 481293b22..505bf3b03 100644 --- a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java @@ -13,7 +13,9 @@ import java.time.Duration; @Service public class OpenAiModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "OPEN_AI"; + public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { @@ -23,7 +25,11 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean { .modelName(modelConfig.getModelName()) .apiKey(modelConfig.keyDecrypt()) .temperature(modelConfig.getTemperature()) + .topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()) .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) + .logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()) .build(); } diff --git a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java index 60bd411cb..74f81408b 100644 --- 
a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java @@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.qianfan.QianfanChatModel; import dev.langchain4j.model.qianfan.QianfanEmbeddingModel; import org.springframework.beans.factory.InitializingBean; import org.springframework.stereotype.Service; @@ -12,10 +13,22 @@ import org.springframework.stereotype.Service; public class QianfanModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "QIANFAN"; + public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return null; + return QianfanChatModel.builder() + .baseUrl(modelConfig.getBaseUrl()) + .apiKey(modelConfig.getApiKey()) + .secretKey(modelConfig.getSecretKey()) + .endpoint(modelConfig.getEndpoint()) + .modelName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()) + .topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()) + .logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()) + .build(); } @Override diff --git a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java index d37d26408..10fb5f0c7 100644 --- a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java @@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import 
dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.zhipu.ZhipuAiChatModel; import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel; import org.springframework.beans.factory.InitializingBean; import org.springframework.stereotype.Service; @@ -11,10 +12,20 @@ import org.springframework.stereotype.Service; @Service public class ZhipuModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "ZHIPU"; + public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return null; + return ZhipuAiChatModel.builder() + .baseUrl(modelConfig.getBaseUrl()) + .apiKey(modelConfig.getApiKey()) + .model(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()) + .topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()) + .logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()) + .build(); } @Override