mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(chat) Support large models qianfan, zhipu, Azure, LocalAi, Dashscope, and handle the apiKey configuration as hidden. (#1552)
This commit is contained in:
@@ -4,8 +4,15 @@ import com.google.common.collect.ImmutableMap;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
|
import dev.ai4j.openai4j.chat.ChatCompletionModel;
|
||||||
|
import dev.langchain4j.model.dashscope.QwenModelName;
|
||||||
|
import dev.langchain4j.provider.AzureModelFactory;
|
||||||
|
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||||
|
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||||
import dev.langchain4j.provider.OllamaModelFactory;
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
|
import dev.langchain4j.provider.QianfanModelFactory;
|
||||||
|
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@@ -18,49 +25,36 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
public static final Parameter CHAT_MODEL_PROVIDER =
|
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||||
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
|
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
|
||||||
"接口协议", "",
|
"接口协议", "", "list",
|
||||||
"list", "对话模型配置",
|
"对话模型配置", getCandidateValues());
|
||||||
getCandidateValues());
|
|
||||||
|
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||||
new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1",
|
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
"BaseUrl", "", "string",
|
"BaseUrl", "", "string",
|
||||||
"对话模型配置", null,
|
"对话模型配置", null, getBaseUrlDependency());
|
||||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
public static final Parameter CHAT_MODEL_ENDPOINT =
|
||||||
getCandidateValues(),
|
new Parameter("s2.chat.model.endpoint", "llama_2_70b",
|
||||||
ImmutableMap.of(
|
"Endpoint", "", "string",
|
||||||
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
|
"对话模型配置", null, getEndpointDependency());
|
||||||
OllamaModelFactory.PROVIDER, "http://localhost:11434")
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_API_KEY =
|
public static final Parameter CHAT_MODEL_API_KEY =
|
||||||
new Parameter("s2.chat.model.api.key", "demo",
|
new Parameter("s2.chat.model.api.key", DEMO,
|
||||||
"ApiKey", "",
|
"ApiKey", "", "password",
|
||||||
"string", "对话模型配置", null,
|
"对话模型配置", null, getApiKeyDependency()
|
||||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "demo"))
|
|
||||||
);
|
);
|
||||||
|
public static final Parameter CHAT_MODEL_SECRET_KEY =
|
||||||
|
new Parameter("s2.chat.model.secretKey", "demo",
|
||||||
|
"SecretKey", "", "password",
|
||||||
|
"对话模型配置", null, getSecretKeyDependency());
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_NAME =
|
public static final Parameter CHAT_MODEL_NAME =
|
||||||
new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
|
new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
|
||||||
"ModelName", "",
|
"ModelName", "", "string",
|
||||||
"string", "对话模型配置", null,
|
"对话模型配置", null, getModelNameDependency());
|
||||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
|
||||||
getCandidateValues(),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
|
|
||||||
OllamaModelFactory.PROVIDER, "qwen:0.5b")
|
|
||||||
));
|
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||||
new Parameter("s2.chat.model.temperature", "0.0",
|
new Parameter("s2.chat.model.temperature", "0.0",
|
||||||
"Temperature", "",
|
"Temperature", "",
|
||||||
"slider", "对话模型配置", null,
|
"slider", "对话模型配置");
|
||||||
getDependency(CHAT_MODEL_PROVIDER.getName(),
|
|
||||||
getCandidateValues(),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "0.0", OllamaModelFactory.PROVIDER, "0.0")));
|
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_TIMEOUT =
|
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||||
new Parameter("s2.chat.model.timeout", "60",
|
new Parameter("s2.chat.model.timeout", "60",
|
||||||
@@ -70,8 +64,9 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
@Override
|
@Override
|
||||||
public List<Parameter> getSysParameters() {
|
public List<Parameter> getSysParameters() {
|
||||||
return Lists.newArrayList(
|
return Lists.newArrayList(
|
||||||
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY,
|
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||||
CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
||||||
|
CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,6 +77,8 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
||||||
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
||||||
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||||
|
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
|
||||||
|
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
|
||||||
|
|
||||||
return ChatModelConfig.builder()
|
return ChatModelConfig.builder()
|
||||||
.provider(chatModelProvider)
|
.provider(chatModelProvider)
|
||||||
@@ -90,10 +87,94 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
.modelName(chatModelName)
|
.modelName(chatModelName)
|
||||||
.temperature(Double.valueOf(chatModelTemperature))
|
.temperature(Double.valueOf(chatModelTemperature))
|
||||||
.timeOut(Long.valueOf(chatModelTimeout))
|
.timeOut(Long.valueOf(chatModelTimeout))
|
||||||
|
.endpoint(endpoint)
|
||||||
|
.secretKey(secretKey)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static ArrayList<String> getCandidateValues() {
|
private static List<String> getCandidateValues() {
|
||||||
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER);
|
List<String> candidateValues = getBaseUrlCandidateValues();
|
||||||
|
candidateValues.add(AzureModelFactory.PROVIDER);
|
||||||
|
return candidateValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ArrayList<String> getBaseUrlCandidateValues() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER,
|
||||||
|
LocalAiModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
getBaseUrlCandidateValues(),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||||
|
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||||
|
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||||
|
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||||
|
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER,
|
||||||
|
LocalAiModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER
|
||||||
|
),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, DEMO,
|
||||||
|
QianfanModelFactory.PROVIDER, DEMO,
|
||||||
|
ZhipuModelFactory.PROVIDER, DEMO,
|
||||||
|
LocalAiModelFactory.PROVIDER, DEMO,
|
||||||
|
AzureModelFactory.PROVIDER, DEMO,
|
||||||
|
DashscopeModelFactory.PROVIDER, DEMO
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
getCandidateValues(),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
|
||||||
|
OllamaModelFactory.PROVIDER, "qwen:0.5b",
|
||||||
|
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
|
||||||
|
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GPT_4.toString(),
|
||||||
|
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
|
||||||
|
AzureModelFactory.PROVIDER, "gpt-35-turbo",
|
||||||
|
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER
|
||||||
|
),
|
||||||
|
ImmutableMap.of(
|
||||||
|
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||||
|
QianfanModelFactory.PROVIDER, "llama_2_70b"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(
|
||||||
|
QianfanModelFactory.PROVIDER, DEMO
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,109 +21,34 @@ import java.util.List;
|
|||||||
@Service("EmbeddingModelParameterConfig")
|
@Service("EmbeddingModelParameterConfig")
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||||
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
||||||
"接口协议", "",
|
"接口协议", "", "list",
|
||||||
"list", "向量模型配置",
|
"向量模型配置", getCandidateValues());
|
||||||
getCandidateValues());
|
|
||||||
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
||||||
new Parameter("s2.embedding.model.base.url", "",
|
new Parameter("s2.embedding.model.base.url", "",
|
||||||
"BaseUrl", "",
|
"BaseUrl", "", "string",
|
||||||
"string", "向量模型配置", null,
|
"向量模型配置", null, getBaseUrlDependency()
|
||||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
OpenAiModelFactory.PROVIDER,
|
|
||||||
OllamaModelFactory.PROVIDER,
|
|
||||||
AzureModelFactory.PROVIDER,
|
|
||||||
DashscopeModelFactory.PROVIDER,
|
|
||||||
QianfanModelFactory.PROVIDER,
|
|
||||||
ZhipuModelFactory.PROVIDER
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
|
|
||||||
OllamaModelFactory.PROVIDER, "http://localhost:11434",
|
|
||||||
AzureModelFactory.PROVIDER, "https://xxxx.openai.azure.com/",
|
|
||||||
DashscopeModelFactory.PROVIDER, "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
|
||||||
QianfanModelFactory.PROVIDER, "https://aip.baidubce.com",
|
|
||||||
ZhipuModelFactory.PROVIDER, "https://open.bigmodel.cn/api/paas/v4/"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
||||||
new Parameter("s2.embedding.model.api.key", "",
|
new Parameter("s2.embedding.model.api.key", "",
|
||||||
"ApiKey", "",
|
"ApiKey", "", "password",
|
||||||
"string", "向量模型配置", null,
|
"向量模型配置", null, getApiKeyDependency());
|
||||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
OpenAiModelFactory.PROVIDER,
|
|
||||||
OllamaModelFactory.PROVIDER,
|
|
||||||
AzureModelFactory.PROVIDER,
|
|
||||||
DashscopeModelFactory.PROVIDER,
|
|
||||||
QianfanModelFactory.PROVIDER,
|
|
||||||
ZhipuModelFactory.PROVIDER
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
OpenAiModelFactory.PROVIDER, "demo",
|
|
||||||
OllamaModelFactory.PROVIDER, "demo",
|
|
||||||
AzureModelFactory.PROVIDER, "demo",
|
|
||||||
DashscopeModelFactory.PROVIDER, "demo",
|
|
||||||
QianfanModelFactory.PROVIDER, "demo",
|
|
||||||
ZhipuModelFactory.PROVIDER, "demo"
|
|
||||||
)
|
|
||||||
));
|
|
||||||
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MODEL_NAME =
|
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||||
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
"ModelName", "",
|
"ModelName", "", "string",
|
||||||
"string", "向量模型配置", null,
|
"向量模型配置", null, getModelNameDependency());
|
||||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
InMemoryModelFactory.PROVIDER,
|
|
||||||
OpenAiModelFactory.PROVIDER,
|
|
||||||
OllamaModelFactory.PROVIDER,
|
|
||||||
AzureModelFactory.PROVIDER,
|
|
||||||
DashscopeModelFactory.PROVIDER,
|
|
||||||
QianfanModelFactory.PROVIDER,
|
|
||||||
ZhipuModelFactory.PROVIDER
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
|
||||||
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
|
|
||||||
OllamaModelFactory.PROVIDER, "all-minilm",
|
|
||||||
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
|
|
||||||
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
|
|
||||||
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
|
|
||||||
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
|
|
||||||
)
|
|
||||||
));
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MODEL_PATH =
|
public static final Parameter EMBEDDING_MODEL_PATH =
|
||||||
new Parameter("s2.embedding.model.path", "",
|
new Parameter("s2.embedding.model.path", "",
|
||||||
"模型路径", "",
|
"模型路径", "", "string",
|
||||||
"string", "向量模型配置", null,
|
"向量模型配置", null, getModelPathDependency());
|
||||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
InMemoryModelFactory.PROVIDER
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
InMemoryModelFactory.PROVIDER, ""
|
|
||||||
)
|
|
||||||
));
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
||||||
new Parameter("s2.embedding.model.vocabulary.path", "",
|
new Parameter("s2.embedding.model.vocabulary.path", "",
|
||||||
"词汇表路径", "",
|
"词汇表路径", "", "string",
|
||||||
"string", "向量模型配置", null,
|
"向量模型配置", null, getModelPathDependency());
|
||||||
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
InMemoryModelFactory.PROVIDER
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
InMemoryModelFactory.PROVIDER, ""
|
|
||||||
)));
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Parameter> getSysParameters() {
|
public List<Parameter> getSysParameters() {
|
||||||
@@ -152,13 +77,80 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static ArrayList<String> getCandidateValues() {
|
private static ArrayList<String> getCandidateValues() {
|
||||||
return Lists.newArrayList(InMemoryModelFactory.PROVIDER,
|
return Lists.newArrayList(
|
||||||
|
InMemoryModelFactory.PROVIDER,
|
||||||
OpenAiModelFactory.PROVIDER,
|
OpenAiModelFactory.PROVIDER,
|
||||||
OllamaModelFactory.PROVIDER,
|
OllamaModelFactory.PROVIDER,
|
||||||
AzureModelFactory.PROVIDER,
|
AzureModelFactory.PROVIDER,
|
||||||
DashscopeModelFactory.PROVIDER,
|
DashscopeModelFactory.PROVIDER,
|
||||||
QianfanModelFactory.PROVIDER,
|
QianfanModelFactory.PROVIDER,
|
||||||
ZhipuModelFactory.PROVIDER);
|
ZhipuModelFactory.PROVIDER
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||||
|
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||||
|
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||||
|
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||||
|
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO,
|
||||||
|
OllamaModelFactory.PROVIDER, DEMO,
|
||||||
|
AzureModelFactory.PROVIDER, DEMO,
|
||||||
|
DashscopeModelFactory.PROVIDER, DEMO,
|
||||||
|
QianfanModelFactory.PROVIDER, DEMO,
|
||||||
|
ZhipuModelFactory.PROVIDER, DEMO)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
InMemoryModelFactory.PROVIDER,
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER
|
||||||
|
),
|
||||||
|
ImmutableMap.of(
|
||||||
|
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
|
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||||
|
OllamaModelFactory.PROVIDER, "all-minilm",
|
||||||
|
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||||
|
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||||
|
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||||
|
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(InMemoryModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Service("EmbeddingStoreParameterConfig")
|
@Service("EmbeddingStoreParameterConfig")
|
||||||
@@ -16,50 +17,23 @@ import java.util.List;
|
|||||||
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||||
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
||||||
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
|
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
|
||||||
"向量库类型", "",
|
"向量库类型", "", "list",
|
||||||
"list", "向量库配置",
|
"向量库配置", getCandidateValues());
|
||||||
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
|
|
||||||
EmbeddingStoreType.MILVUS.name(),
|
|
||||||
EmbeddingStoreType.CHROMA.name()));
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||||
new Parameter("s2.embedding.store.base.url", "",
|
new Parameter("s2.embedding.store.base.url", "",
|
||||||
"BaseUrl", "",
|
"BaseUrl", "", "string",
|
||||||
"string", "向量库配置", null,
|
"向量库配置", null, getBaseUrlDependency());
|
||||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
EmbeddingStoreType.MILVUS.name(),
|
|
||||||
EmbeddingStoreType.CHROMA.name()
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
|
||||||
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
|
|
||||||
)
|
|
||||||
));
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_API_KEY =
|
public static final Parameter EMBEDDING_STORE_API_KEY =
|
||||||
new Parameter("s2.embedding.store.api.key", "",
|
new Parameter("s2.embedding.store.api.key", "",
|
||||||
"ApiKey", "",
|
"ApiKey", "", "password",
|
||||||
"string", "向量库配置", null,
|
"向量库配置", null, getApiKeyDependency());
|
||||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
EmbeddingStoreType.MILVUS.name()
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
EmbeddingStoreType.MILVUS.name(), "demo"
|
|
||||||
)
|
|
||||||
));
|
|
||||||
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
||||||
new Parameter("s2.embedding.store.persist.path", "/tmp",
|
new Parameter("s2.embedding.store.persist.path", "/tmp",
|
||||||
"持久化路径", "",
|
"持久化路径", "", "string",
|
||||||
"string", "向量库配置", null,
|
"向量库配置", null, getPathDependency());
|
||||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
EmbeddingStoreType.IN_MEMORY.name()
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
EmbeddingStoreType.IN_MEMORY.name(), "/tmp"
|
|
||||||
)));
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
||||||
new Parameter("s2.embedding.store.timeout", "60",
|
new Parameter("s2.embedding.store.timeout", "60",
|
||||||
@@ -68,16 +42,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_DIMENSION =
|
public static final Parameter EMBEDDING_STORE_DIMENSION =
|
||||||
new Parameter("s2.embedding.store.dimension", "",
|
new Parameter("s2.embedding.store.dimension", "",
|
||||||
"纬度", "",
|
"纬度", "", "number",
|
||||||
"number", "向量库配置", null,
|
"向量库配置", null, getDimensionDependency());
|
||||||
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(
|
|
||||||
EmbeddingStoreType.MILVUS.name()
|
|
||||||
),
|
|
||||||
ImmutableMap.of(
|
|
||||||
EmbeddingStoreType.MILVUS.name(), "384"
|
|
||||||
)
|
|
||||||
));
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Parameter> getSysParameters() {
|
public List<Parameter> getSysParameters() {
|
||||||
@@ -97,13 +63,50 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
|
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
|
||||||
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
|
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
|
||||||
}
|
}
|
||||||
return EmbeddingStoreConfig.builder()
|
return EmbeddingStoreConfig.builder().provider(provider)
|
||||||
.provider(provider)
|
.baseUrl(baseUrl).apiKey(apiKey).persistPath(persistPath)
|
||||||
.baseUrl(baseUrl)
|
.timeOut(Long.valueOf(timeOut)).dimension(dimension).build();
|
||||||
.apiKey(apiKey)
|
}
|
||||||
.persistPath(persistPath)
|
|
||||||
.timeOut(Long.valueOf(timeOut))
|
private static ArrayList<String> getCandidateValues() {
|
||||||
.dimension(dimension)
|
return Lists.newArrayList(
|
||||||
.build();
|
EmbeddingStoreType.IN_MEMORY.name(),
|
||||||
|
EmbeddingStoreType.MILVUS.name(),
|
||||||
|
EmbeddingStoreType.CHROMA.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()),
|
||||||
|
ImmutableMap.of(
|
||||||
|
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||||
|
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||||
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getPathDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()),
|
||||||
|
ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), "/tmp"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
EmbeddingStoreType.MILVUS.name()
|
||||||
|
),
|
||||||
|
ImmutableMap.of(
|
||||||
|
EmbeddingStoreType.MILVUS.name(), "384"
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -14,7 +14,7 @@ import java.util.Map;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
public abstract class ParameterConfig {
|
public abstract class ParameterConfig {
|
||||||
|
public static final String DEMO = "demo";
|
||||||
@Autowired
|
@Autowired
|
||||||
private SystemConfigService sysConfigService;
|
private SystemConfigService sysConfigService;
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,12 @@ public class ChatModelConfig implements Serializable {
|
|||||||
private String modelName;
|
private String modelName;
|
||||||
private Double temperature = 0.0d;
|
private Double temperature = 0.0d;
|
||||||
private Long timeOut = 60L;
|
private Long timeOut = 60L;
|
||||||
|
private String endpoint;
|
||||||
|
private String secretKey;
|
||||||
|
private Double topP;
|
||||||
|
private Integer maxRetries = 3;
|
||||||
|
private Boolean logRequests = false;
|
||||||
|
private Boolean logResponses = false;
|
||||||
|
|
||||||
public String keyDecrypt() {
|
public String keyDecrypt() {
|
||||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||||
|
|||||||
@@ -14,15 +14,19 @@ import java.time.Duration;
|
|||||||
@Service
|
@Service
|
||||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||||
public static final String PROVIDER = "AZURE";
|
public static final String PROVIDER = "AZURE";
|
||||||
|
public static final String DEFAULT_BASE_URL = "https://xxxx.openai.azure.com/";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||||
.endpoint(modelConfig.getBaseUrl())
|
.endpoint(modelConfig.getEndpoint())
|
||||||
.apiKey(modelConfig.getApiKey())
|
.apiKey(modelConfig.getApiKey())
|
||||||
.deploymentName(modelConfig.getModelName())
|
.deploymentName(modelConfig.getModelName())
|
||||||
.temperature(modelConfig.getTemperature())
|
.temperature(modelConfig.getTemperature())
|
||||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()));
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
|
.topP(modelConfig.getTopP())
|
||||||
|
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
|
||||||
|
.logRequestsAndResponses(modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
|
||||||
return builder.build();
|
return builder.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import org.springframework.stereotype.Service;
|
|||||||
@Service
|
@Service
|
||||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||||
public static final String PROVIDER = "DASHSCOPE";
|
public static final String PROVIDER = "DASHSCOPE";
|
||||||
|
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
@@ -21,6 +22,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
|||||||
.modelName(modelConfig.getModelName())
|
.modelName(modelConfig.getModelName())
|
||||||
.temperature(modelConfig.getTemperature() == null ? 0L :
|
.temperature(modelConfig.getTemperature() == null ? 0L :
|
||||||
modelConfig.getTemperature().floatValue())
|
modelConfig.getTemperature().floatValue())
|
||||||
|
.topP(modelConfig.getTopP())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import java.time.Duration;
|
|||||||
@Service
|
@Service
|
||||||
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||||
public static final String PROVIDER = "LOCAL_AI";
|
public static final String PROVIDER = "LOCAL_AI";
|
||||||
|
public static final String DEFAULT_BASE_URL = "http://localhost:8080";
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
return LocalAiChatModel
|
return LocalAiChatModel
|
||||||
@@ -23,6 +23,10 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
|||||||
.modelName(modelConfig.getModelName())
|
.modelName(modelConfig.getModelName())
|
||||||
.temperature(modelConfig.getTemperature())
|
.temperature(modelConfig.getTemperature())
|
||||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||||
|
.topP(modelConfig.getTopP())
|
||||||
|
.logRequests(modelConfig.getLogRequests())
|
||||||
|
.logResponses(modelConfig.getLogResponses())
|
||||||
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,6 +35,9 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
|||||||
return LocalAiEmbeddingModel.builder()
|
return LocalAiEmbeddingModel.builder()
|
||||||
.baseUrl(embeddingModel.getBaseUrl())
|
.baseUrl(embeddingModel.getBaseUrl())
|
||||||
.modelName(embeddingModel.getModelName())
|
.modelName(embeddingModel.getModelName())
|
||||||
|
.maxRetries(embeddingModel.getMaxRetries())
|
||||||
|
.logRequests(embeddingModel.getLogRequests())
|
||||||
|
.logResponses(embeddingModel.getLogResponses())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ import java.time.Duration;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||||
|
|
||||||
public static final String PROVIDER = "OLLAMA";
|
public static final String PROVIDER = "OLLAMA";
|
||||||
|
public static final String DEFAULT_BASE_URL = "http://localhost:11434";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
@@ -23,6 +25,10 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
|||||||
.modelName(modelConfig.getModelName())
|
.modelName(modelConfig.getModelName())
|
||||||
.temperature(modelConfig.getTemperature())
|
.temperature(modelConfig.getTemperature())
|
||||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||||
|
.topP(modelConfig.getTopP())
|
||||||
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
|
.logRequests(modelConfig.getLogRequests())
|
||||||
|
.logResponses(modelConfig.getLogResponses())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ import java.time.Duration;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||||
|
|
||||||
public static final String PROVIDER = "OPEN_AI";
|
public static final String PROVIDER = "OPEN_AI";
|
||||||
|
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
@@ -23,7 +25,11 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
|||||||
.modelName(modelConfig.getModelName())
|
.modelName(modelConfig.getModelName())
|
||||||
.apiKey(modelConfig.keyDecrypt())
|
.apiKey(modelConfig.keyDecrypt())
|
||||||
.temperature(modelConfig.getTemperature())
|
.temperature(modelConfig.getTemperature())
|
||||||
|
.topP(modelConfig.getTopP())
|
||||||
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||||
|
.logRequests(modelConfig.getLogRequests())
|
||||||
|
.logResponses(modelConfig.getLogResponses())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
||||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
import org.springframework.beans.factory.InitializingBean;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -12,10 +13,22 @@ import org.springframework.stereotype.Service;
|
|||||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||||
|
|
||||||
public static final String PROVIDER = "QIANFAN";
|
public static final String PROVIDER = "QIANFAN";
|
||||||
|
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
return null;
|
return QianfanChatModel.builder()
|
||||||
|
.baseUrl(modelConfig.getBaseUrl())
|
||||||
|
.apiKey(modelConfig.getApiKey())
|
||||||
|
.secretKey(modelConfig.getSecretKey())
|
||||||
|
.endpoint(modelConfig.getEndpoint())
|
||||||
|
.modelName(modelConfig.getModelName())
|
||||||
|
.temperature(modelConfig.getTemperature())
|
||||||
|
.topP(modelConfig.getTopP())
|
||||||
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
|
.logRequests(modelConfig.getLogRequests())
|
||||||
|
.logResponses(modelConfig.getLogResponses())
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
import org.springframework.beans.factory.InitializingBean;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -11,10 +12,20 @@ import org.springframework.stereotype.Service;
|
|||||||
@Service
|
@Service
|
||||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||||
public static final String PROVIDER = "ZHIPU";
|
public static final String PROVIDER = "ZHIPU";
|
||||||
|
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
return null;
|
return ZhipuAiChatModel.builder()
|
||||||
|
.baseUrl(modelConfig.getBaseUrl())
|
||||||
|
.apiKey(modelConfig.getApiKey())
|
||||||
|
.model(modelConfig.getModelName())
|
||||||
|
.temperature(modelConfig.getTemperature())
|
||||||
|
.topP(modelConfig.getTopP())
|
||||||
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
|
.logRequests(modelConfig.getLogRequests())
|
||||||
|
.logResponses(modelConfig.getLogResponses())
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
Reference in New Issue
Block a user