mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(chat) Support large models qianfan, zhipu, Azure, LocalAi, Dashscope, and handle the apiKey configuration as hidden. (#1552)
This commit is contained in:
@@ -4,8 +4,15 @@ import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionModel;
|
||||
import dev.langchain4j.model.dashscope.QwenModelName;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -18,49 +25,36 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
|
||||
"接口协议", "",
|
||||
"list", "对话模型配置",
|
||||
getCandidateValues());
|
||||
|
||||
"接口协议", "", "list",
|
||||
"对话模型配置", getCandidateValues());
|
||||
|
||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||
new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1",
|
||||
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
"BaseUrl", "", "string",
|
||||
"对话模型配置", null,
|
||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
|
||||
OllamaModelFactory.PROVIDER, "http://localhost:11434")
|
||||
)
|
||||
);
|
||||
|
||||
"对话模型配置", null, getBaseUrlDependency());
|
||||
public static final Parameter CHAT_MODEL_ENDPOINT =
|
||||
new Parameter("s2.chat.model.endpoint", "llama_2_70b",
|
||||
"Endpoint", "", "string",
|
||||
"对话模型配置", null, getEndpointDependency());
|
||||
public static final Parameter CHAT_MODEL_API_KEY =
|
||||
new Parameter("s2.chat.model.api.key", "demo",
|
||||
"ApiKey", "",
|
||||
"string", "对话模型配置", null,
|
||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "demo"))
|
||||
new Parameter("s2.chat.model.api.key", DEMO,
|
||||
"ApiKey", "", "password",
|
||||
"对话模型配置", null, getApiKeyDependency()
|
||||
);
|
||||
public static final Parameter CHAT_MODEL_SECRET_KEY =
|
||||
new Parameter("s2.chat.model.secretKey", "demo",
|
||||
"SecretKey", "", "password",
|
||||
"对话模型配置", null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_NAME =
|
||||
new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
|
||||
"ModelName", "",
|
||||
"string", "对话模型配置", null,
|
||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
|
||||
OllamaModelFactory.PROVIDER, "qwen:0.5b")
|
||||
));
|
||||
"ModelName", "", "string",
|
||||
"对话模型配置", null, getModelNameDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||
new Parameter("s2.chat.model.temperature", "0.0",
|
||||
"Temperature", "",
|
||||
"slider", "对话模型配置", null,
|
||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "0.0", OllamaModelFactory.PROVIDER, "0.0")));
|
||||
"slider", "对话模型配置");
|
||||
|
||||
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||
new Parameter("s2.chat.model.timeout", "60",
|
||||
@@ -70,8 +64,9 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY,
|
||||
CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
||||
CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||
);
|
||||
}
|
||||
|
||||
@@ -82,6 +77,8 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
||||
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
||||
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
|
||||
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
|
||||
|
||||
return ChatModelConfig.builder()
|
||||
.provider(chatModelProvider)
|
||||
@@ -90,10 +87,94 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
.modelName(chatModelName)
|
||||
.temperature(Double.valueOf(chatModelTemperature))
|
||||
.timeOut(Long.valueOf(chatModelTimeout))
|
||||
.endpoint(endpoint)
|
||||
.secretKey(secretKey)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER);
|
||||
private static List<String> getCandidateValues() {
|
||||
List<String> candidateValues = getBaseUrlCandidateValues();
|
||||
candidateValues.add(AzureModelFactory.PROVIDER);
|
||||
return candidateValues;
|
||||
}
|
||||
|
||||
private static ArrayList<String> getBaseUrlCandidateValues() {
|
||||
return Lists.newArrayList(
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getBaseUrlCandidateValues(),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, DEMO,
|
||||
QianfanModelFactory.PROVIDER, DEMO,
|
||||
ZhipuModelFactory.PROVIDER, DEMO,
|
||||
LocalAiModelFactory.PROVIDER, DEMO,
|
||||
AzureModelFactory.PROVIDER, DEMO,
|
||||
DashscopeModelFactory.PROVIDER, DEMO
|
||||
));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
|
||||
OllamaModelFactory.PROVIDER, "qwen:0.5b",
|
||||
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
|
||||
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GPT_4.toString(),
|
||||
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
|
||||
AzureModelFactory.PROVIDER, "gpt-35-turbo",
|
||||
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
AzureModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, "llama_2_70b"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(
|
||||
QianfanModelFactory.PROVIDER, DEMO
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,109 +21,34 @@ import java.util.List;
|
||||
@Service("EmbeddingModelParameterConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
||||
"接口协议", "",
|
||||
"list", "向量模型配置",
|
||||
getCandidateValues());
|
||||
"接口协议", "", "list",
|
||||
"向量模型配置", getCandidateValues());
|
||||
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
||||
new Parameter("s2.embedding.model.base.url", "",
|
||||
"BaseUrl", "",
|
||||
"string", "向量模型配置", null,
|
||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
|
||||
OllamaModelFactory.PROVIDER, "http://localhost:11434",
|
||||
AzureModelFactory.PROVIDER, "https://xxxx.openai.azure.com/",
|
||||
DashscopeModelFactory.PROVIDER, "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
QianfanModelFactory.PROVIDER, "https://aip.baidubce.com",
|
||||
ZhipuModelFactory.PROVIDER, "https://open.bigmodel.cn/api/paas/v4/"
|
||||
)
|
||||
)
|
||||
"BaseUrl", "", "string",
|
||||
"向量模型配置", null, getBaseUrlDependency()
|
||||
);
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
||||
new Parameter("s2.embedding.model.api.key", "",
|
||||
"ApiKey", "",
|
||||
"string", "向量模型配置", null,
|
||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, "demo",
|
||||
OllamaModelFactory.PROVIDER, "demo",
|
||||
AzureModelFactory.PROVIDER, "demo",
|
||||
DashscopeModelFactory.PROVIDER, "demo",
|
||||
QianfanModelFactory.PROVIDER, "demo",
|
||||
ZhipuModelFactory.PROVIDER, "demo"
|
||||
)
|
||||
));
|
||||
|
||||
"ApiKey", "", "password",
|
||||
"向量模型配置", null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
"ModelName", "",
|
||||
"string", "向量模型配置", null,
|
||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
InMemoryModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
OllamaModelFactory.PROVIDER, "all-minilm",
|
||||
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
|
||||
)
|
||||
));
|
||||
"ModelName", "", "string",
|
||||
"向量模型配置", null, getModelNameDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_PATH =
|
||||
new Parameter("s2.embedding.model.path", "",
|
||||
"模型路径", "",
|
||||
"string", "向量模型配置", null,
|
||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
InMemoryModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
InMemoryModelFactory.PROVIDER, ""
|
||||
)
|
||||
));
|
||||
|
||||
"模型路径", "", "string",
|
||||
"向量模型配置", null, getModelPathDependency());
|
||||
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
||||
new Parameter("s2.embedding.model.vocabulary.path", "",
|
||||
"词汇表路径", "",
|
||||
"string", "向量模型配置", null,
|
||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
InMemoryModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
InMemoryModelFactory.PROVIDER, ""
|
||||
)));
|
||||
"词汇表路径", "", "string",
|
||||
"向量模型配置", null, getModelPathDependency());
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
@@ -152,13 +77,80 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
}
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(InMemoryModelFactory.PROVIDER,
|
||||
return Lists.newArrayList(
|
||||
InMemoryModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER);
|
||||
ZhipuModelFactory.PROVIDER
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO,
|
||||
OllamaModelFactory.PROVIDER, DEMO,
|
||||
AzureModelFactory.PROVIDER, DEMO,
|
||||
DashscopeModelFactory.PROVIDER, DEMO,
|
||||
QianfanModelFactory.PROVIDER, DEMO,
|
||||
ZhipuModelFactory.PROVIDER, DEMO)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
InMemoryModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
OllamaModelFactory.PROVIDER, "all-minilm",
|
||||
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER),
|
||||
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Service("EmbeddingStoreParameterConfig")
|
||||
@@ -16,50 +17,23 @@ import java.util.List;
|
||||
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
||||
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
|
||||
"向量库类型", "",
|
||||
"list", "向量库配置",
|
||||
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
|
||||
EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.CHROMA.name()));
|
||||
"向量库类型", "", "list",
|
||||
"向量库配置", getCandidateValues());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||
new Parameter("s2.embedding.store.base.url", "",
|
||||
"BaseUrl", "",
|
||||
"string", "向量库配置", null,
|
||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.CHROMA.name()
|
||||
),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
|
||||
)
|
||||
));
|
||||
"BaseUrl", "", "string",
|
||||
"向量库配置", null, getBaseUrlDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_API_KEY =
|
||||
new Parameter("s2.embedding.store.api.key", "",
|
||||
"ApiKey", "",
|
||||
"string", "向量库配置", null,
|
||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.MILVUS.name()
|
||||
),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.MILVUS.name(), "demo"
|
||||
)
|
||||
));
|
||||
"ApiKey", "", "password",
|
||||
"向量库配置", null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
||||
new Parameter("s2.embedding.store.persist.path", "/tmp",
|
||||
"持久化路径", "",
|
||||
"string", "向量库配置", null,
|
||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.IN_MEMORY.name()
|
||||
),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.IN_MEMORY.name(), "/tmp"
|
||||
)));
|
||||
"持久化路径", "", "string",
|
||||
"向量库配置", null, getPathDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
||||
new Parameter("s2.embedding.store.timeout", "60",
|
||||
@@ -68,16 +42,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_DIMENSION =
|
||||
new Parameter("s2.embedding.store.dimension", "",
|
||||
"纬度", "",
|
||||
"number", "向量库配置", null,
|
||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.MILVUS.name()
|
||||
),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.MILVUS.name(), "384"
|
||||
)
|
||||
));
|
||||
"纬度", "", "number",
|
||||
"向量库配置", null, getDimensionDependency());
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
@@ -97,13 +63,50 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
|
||||
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
|
||||
}
|
||||
return EmbeddingStoreConfig.builder()
|
||||
.provider(provider)
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apiKey)
|
||||
.persistPath(persistPath)
|
||||
.timeOut(Long.valueOf(timeOut))
|
||||
.dimension(dimension)
|
||||
.build();
|
||||
return EmbeddingStoreConfig.builder().provider(provider)
|
||||
.baseUrl(baseUrl).apiKey(apiKey).persistPath(persistPath)
|
||||
.timeOut(Long.valueOf(timeOut)).dimension(dimension).build();
|
||||
}
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(
|
||||
EmbeddingStoreType.IN_MEMORY.name(),
|
||||
EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.CHROMA.name());
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getPathDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), "/tmp"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.MILVUS.name()
|
||||
),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.MILVUS.name(), "384"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import java.util.Map;
|
||||
|
||||
@Service
|
||||
public abstract class ParameterConfig {
|
||||
|
||||
public static final String DEMO = "demo";
|
||||
@Autowired
|
||||
private SystemConfigService sysConfigService;
|
||||
|
||||
|
||||
@@ -20,6 +20,12 @@ public class ChatModelConfig implements Serializable {
|
||||
private String modelName;
|
||||
private Double temperature = 0.0d;
|
||||
private Long timeOut = 60L;
|
||||
private String endpoint;
|
||||
private String secretKey;
|
||||
private Double topP;
|
||||
private Integer maxRetries = 3;
|
||||
private Boolean logRequests = false;
|
||||
private Boolean logResponses = false;
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||
|
||||
@@ -14,15 +14,19 @@ import java.time.Duration;
|
||||
@Service
|
||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "AZURE";
|
||||
public static final String DEFAULT_BASE_URL = "https://xxxx.openai.azure.com/";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(modelConfig.getBaseUrl())
|
||||
.endpoint(modelConfig.getEndpoint())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.deploymentName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()));
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.topP(modelConfig.getTopP())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
|
||||
.logRequestsAndResponses(modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import org.springframework.stereotype.Service;
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "DASHSCOPE";
|
||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
@@ -21,6 +22,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature() == null ? 0L :
|
||||
modelConfig.getTemperature().floatValue())
|
||||
.topP(modelConfig.getTopP())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import java.time.Duration;
|
||||
@Service
|
||||
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "LOCAL_AI";
|
||||
|
||||
public static final String DEFAULT_BASE_URL = "http://localhost:8080";
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return LocalAiChatModel
|
||||
@@ -23,6 +23,10 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.topP(modelConfig.getTopP())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -31,6 +35,9 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
return LocalAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModel.getBaseUrl())
|
||||
.modelName(embeddingModel.getModelName())
|
||||
.maxRetries(embeddingModel.getMaxRetries())
|
||||
.logRequests(embeddingModel.getLogRequests())
|
||||
.logResponses(embeddingModel.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
public static final String PROVIDER = "OLLAMA";
|
||||
public static final String DEFAULT_BASE_URL = "http://localhost:11434";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
@@ -23,6 +25,10 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
public static final String PROVIDER = "OPEN_AI";
|
||||
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
@@ -23,7 +25,11 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
.modelName(modelConfig.getModelName())
|
||||
.apiKey(modelConfig.keyDecrypt())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -12,10 +13,22 @@ import org.springframework.stereotype.Service;
|
||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
public static final String PROVIDER = "QIANFAN";
|
||||
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return null;
|
||||
return QianfanChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.secretKey(modelConfig.getSecretKey())
|
||||
.endpoint(modelConfig.getEndpoint())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -11,10 +12,20 @@ import org.springframework.stereotype.Service;
|
||||
@Service
|
||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "ZHIPU";
|
||||
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return null;
|
||||
return ZhipuAiChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.model(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user