mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[feature][chat]Refactor chat model config related codes.#1739
This commit is contained in:
@@ -1,131 +0,0 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("ChatModelParameterConfig")
|
||||
@Slf4j
|
||||
public class ChatModelParameterConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider",
|
||||
OpenAiModelFactory.PROVIDER, "接口协议", "", "list", "对话模型配置", getCandidateValues());
|
||||
|
||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL, "BaseUrl",
|
||||
"", "string", "对话模型配置", null, getBaseUrlDependency());
|
||||
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("s2.chat.model.endpoint",
|
||||
"llama_2_70b", "Endpoint", "", "string", "对话模型配置", null, getEndpointDependency());
|
||||
public static final Parameter CHAT_MODEL_API_KEY = new Parameter("s2.chat.model.api.key", DEMO,
|
||||
"ApiKey", "", "password", "对话模型配置", null, getApiKeyDependency());
|
||||
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("s2.chat.model.secretKey",
|
||||
"demo", "SecretKey", "", "password", "对话模型配置", null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_NAME = new Parameter("s2.chat.model.name",
|
||||
"gpt-4o-mini", "ModelName", "", "string", "对话模型配置", null, getModelNameDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
|
||||
new Parameter("s2.chat.model.enableSearch", "false", "是否启用搜索增强功能,设为false表示不启用", "",
|
||||
"bool", "对话模型配置", null, getEnableSearchDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter(
|
||||
"s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置");
|
||||
|
||||
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||
new Parameter("s2.chat.model.timeout", "60", "超时时间(秒)", "", "number", "对话模型配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
||||
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
|
||||
}
|
||||
|
||||
public ChatModelConfig convert() {
|
||||
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
|
||||
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
|
||||
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
|
||||
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
||||
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
||||
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
|
||||
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
|
||||
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
|
||||
|
||||
return ChatModelConfig.builder().provider(chatModelProvider).baseUrl(chatModelBaseUrl)
|
||||
.apiKey(chatModelApiKey).modelName(chatModelName)
|
||||
.enableSearch(Boolean.valueOf(enableSearch))
|
||||
.temperature(Double.valueOf(chatModelTemperature))
|
||||
.timeOut(Long.valueOf(chatModelTimeout)).endpoint(endpoint).secretKey(secretKey)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static List<String> getCandidateValues() {
|
||||
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER,
|
||||
DEMO, ZhipuModelFactory.PROVIDER, DEMO, LocalAiModelFactory.PROVIDER, DEMO,
|
||||
AzureModelFactory.PROVIDER, DEMO, DashscopeModelFactory.PROVIDER, DEMO));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
|
||||
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
||||
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
|
||||
}
|
||||
}
|
||||
@@ -21,30 +21,32 @@ import java.util.List;
|
||||
@Service("EmbeddingModelParameterConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
private static final String MODULE_NAME = "嵌入模型配置";
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "",
|
||||
"list", "向量模型配置", getCandidateValues());
|
||||
"list", MODULE_NAME, getCandidateValues());
|
||||
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
||||
new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", "向量模型配置",
|
||||
new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
|
||||
null, getBaseUrlDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
||||
new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", "向量模型配置",
|
||||
new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", MODULE_NAME,
|
||||
null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
|
||||
new Parameter("s2.embedding.model.secretKey", "demo", "SecretKey", "", "password",
|
||||
"向量模型配置", null, getSecretKeyDependency());
|
||||
MODULE_NAME, null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
"ModelName", "", "string", "向量模型配置", null, getModelNameDependency());
|
||||
"ModelName", "", "string", MODULE_NAME, null, getModelNameDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path",
|
||||
"", "模型路径", "", "string", "向量模型配置", null, getModelPathDependency());
|
||||
"", "模型路径", "", "string", MODULE_NAME, null, getModelPathDependency());
|
||||
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
||||
new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string", "向量模型配置",
|
||||
null, getModelPathDependency());
|
||||
new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string",
|
||||
MODULE_NAME, null, getModelPathDependency());
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
|
||||
@@ -15,32 +15,34 @@ import java.util.List;
|
||||
@Service("EmbeddingStoreParameterConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
private static final String MODULE_NAME = "向量数据库配置";
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
|
||||
"s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
|
||||
"目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", "list", "向量库配置", getCandidateValues());
|
||||
"目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", "list", MODULE_NAME, getCandidateValues());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", "向量库配置", null,
|
||||
getBaseUrlDependency());
|
||||
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
|
||||
null, getBaseUrlDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_API_KEY =
|
||||
new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", "向量库配置", null,
|
||||
getApiKeyDependency());
|
||||
new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", MODULE_NAME,
|
||||
null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
||||
new Parameter("s2.embedding.store.persist.path", "", "持久化路径",
|
||||
"默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string",
|
||||
"向量库配置", null, getPathDependency());
|
||||
MODULE_NAME, null, getPathDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
||||
new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", "向量库配置");
|
||||
new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", MODULE_NAME);
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_DIMENSION =
|
||||
new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", "向量库配置", null,
|
||||
new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", MODULE_NAME, null,
|
||||
getDimensionDependency());
|
||||
public static final Parameter EMBEDDING_STORE_DATABASE_NAME =
|
||||
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
|
||||
"向量库配置", null, getDatabaseNameDependency());
|
||||
MODULE_NAME, null, getDatabaseNameDependency());
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class PromptConfig {
|
||||
|
||||
private String promptTemplate;
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class VisualConfig {
|
||||
|
||||
private boolean enableSimpleMode;
|
||||
|
||||
private boolean showDebugInfo;
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelParameterConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
@@ -14,6 +13,10 @@ import java.util.Map;
|
||||
|
||||
public class ModelProvider {
|
||||
|
||||
public static final ChatModelConfig DEMO_CHAT_MODEL =
|
||||
ChatModelConfig.builder().provider("open_ai").baseUrl("https://api.openai.com/v1")
|
||||
.apiKey("demo").modelName("gpt-4o-mini").temperature(0.0).timeOut(60L).build();
|
||||
|
||||
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
||||
|
||||
public static void add(String provider, ModelFactory modelFactory) {
|
||||
@@ -27,9 +30,7 @@ public class ModelProvider {
|
||||
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
|
||||
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||
ChatModelParameterConfig parameterConfig =
|
||||
ContextUtils.getBean(ChatModelParameterConfig.class);
|
||||
modelConfig = parameterConfig.convert();
|
||||
modelConfig = DEMO_CHAT_MODEL;
|
||||
}
|
||||
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
|
||||
Reference in New Issue
Block a user