[feature][chat]Refactor chat model config related codes.#1739

This commit is contained in:
jerryjzhang
2024-10-09 17:27:07 +08:00
parent 60b0a1a1a1
commit 248f4f83f6
53 changed files with 275 additions and 251 deletions

View File

@@ -1,131 +0,0 @@
package com.tencent.supersonic.common.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory;
import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("ChatModelParameterConfig")
@Slf4j
public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider",
OpenAiModelFactory.PROVIDER, "接口协议", "", "list", "对话模型配置", getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL, "BaseUrl",
"", "string", "对话模型配置", null, getBaseUrlDependency());
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("s2.chat.model.endpoint",
"llama_2_70b", "Endpoint", "", "string", "对话模型配置", null, getEndpointDependency());
public static final Parameter CHAT_MODEL_API_KEY = new Parameter("s2.chat.model.api.key", DEMO,
"ApiKey", "", "password", "对话模型配置", null, getApiKeyDependency());
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("s2.chat.model.secretKey",
"demo", "SecretKey", "", "password", "对话模型配置", null, getSecretKeyDependency());
public static final Parameter CHAT_MODEL_NAME = new Parameter("s2.chat.model.name",
"gpt-4o-mini", "ModelName", "", "string", "对话模型配置", null, getModelNameDependency());
public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
new Parameter("s2.chat.model.enableSearch", "false", "是否启用搜索增强功能设为false表示不启用", "",
"bool", "对话模型配置", null, getEnableSearchDependency());
public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter(
"s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置");
public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60", "超时时间(秒)", "", "number", "对话模型配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
}
public ChatModelConfig convert() {
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
return ChatModelConfig.builder().provider(chatModelProvider).baseUrl(chatModelBaseUrl)
.apiKey(chatModelApiKey).modelName(chatModelName)
.enableSearch(Boolean.valueOf(enableSearch))
.temperature(Double.valueOf(chatModelTemperature))
.timeOut(Long.valueOf(chatModelTimeout)).endpoint(endpoint).secretKey(secretKey)
.build();
}
private static List<String> getCandidateValues() {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL));
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER,
DEMO, ZhipuModelFactory.PROVIDER, DEMO, LocalAiModelFactory.PROVIDER, DEMO,
AzureModelFactory.PROVIDER, DEMO, DashscopeModelFactory.PROVIDER, DEMO));
}
private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME));
}
private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
}
private static List<Parameter.Dependency> getEnableSearchDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
}
private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
}
}

View File

@@ -21,30 +21,32 @@ import java.util.List;
@Service("EmbeddingModelParameterConfig")
@Slf4j
public class EmbeddingModelParameterConfig extends ParameterConfig {
private static final String MODULE_NAME = "嵌入模型配置";
public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "",
"list", "向量模型配置", getCandidateValues());
"list", MODULE_NAME, getCandidateValues());
public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", "向量模型配置",
new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
null, getBaseUrlDependency());
public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", "向量模型配置",
new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", MODULE_NAME,
null, getApiKeyDependency());
public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
new Parameter("s2.embedding.model.secretKey", "demo", "SecretKey", "", "password",
"向量模型配置", null, getSecretKeyDependency());
MODULE_NAME, null, getSecretKeyDependency());
public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName", "", "string", "向量模型配置", null, getModelNameDependency());
"ModelName", "", "string", MODULE_NAME, null, getModelNameDependency());
public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path",
"", "模型路径", "", "string", "向量模型配置", null, getModelPathDependency());
"", "模型路径", "", "string", MODULE_NAME, null, getModelPathDependency());
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string", "向量模型配置",
null, getModelPathDependency());
new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string",
MODULE_NAME, null, getModelPathDependency());
@Override
public List<Parameter> getSysParameters() {

View File

@@ -15,32 +15,34 @@ import java.util.List;
@Service("EmbeddingStoreParameterConfig")
@Slf4j
public class EmbeddingStoreParameterConfig extends ParameterConfig {
private static final String MODULE_NAME = "向量数据库配置";
public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
"s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
"目前支持三种类型IN_MEMORY、MILVUS、CHROMA", "list", "向量库配置", getCandidateValues());
"目前支持三种类型IN_MEMORY、MILVUS、CHROMA", "list", MODULE_NAME, getCandidateValues());
public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", "向量库配置", null,
getBaseUrlDependency());
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
null, getBaseUrlDependency());
public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", "向量库配置", null,
getApiKeyDependency());
new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", MODULE_NAME,
null, getApiKeyDependency());
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter("s2.embedding.store.persist.path", "", "持久化路径",
"默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string",
"向量库配置", null, getPathDependency());
MODULE_NAME, null, getPathDependency());
public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", "向量库配置");
new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", MODULE_NAME);
public static final Parameter EMBEDDING_STORE_DIMENSION =
new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", "向量库配置", null,
new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", MODULE_NAME, null,
getDimensionDependency());
public static final Parameter EMBEDDING_STORE_DATABASE_NAME =
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
"向量库配置", null, getDatabaseNameDependency());
MODULE_NAME, null, getDatabaseNameDependency());
@Override
public List<Parameter> getSysParameters() {

View File

@@ -1,13 +0,0 @@
package com.tencent.supersonic.common.config;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PromptConfig {
private String promptTemplate;
}

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.common.config;
import lombok.Data;
@Data
public class VisualConfig {
private boolean enableSimpleMode;
private boolean showDebugInfo;
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelParameterConfig;
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
@@ -14,6 +13,10 @@ import java.util.Map;
public class ModelProvider {
public static final ChatModelConfig DEMO_CHAT_MODEL =
ChatModelConfig.builder().provider("open_ai").baseUrl("https://api.openai.com/v1")
.apiKey("demo").modelName("gpt-4o-mini").temperature(0.0).timeOut(60L).build();
private static final Map<String, ModelFactory> factories = new HashMap<>();
public static void add(String provider, ModelFactory modelFactory) {
@@ -27,9 +30,7 @@ public class ModelProvider {
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
ChatModelParameterConfig parameterConfig =
ContextUtils.getBean(ChatModelParameterConfig.class);
modelConfig = parameterConfig.convert();
modelConfig = DEMO_CHAT_MODEL;
}
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
if (modelFactory != null) {