mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-18 08:17:18 +00:00
(improvement)(chat) Add unit tests for each chatModel and embeddingModel. (#1582)
This commit is contained in:
@@ -4,8 +4,6 @@ import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.model.dashscope.QwenModelName;
|
||||
import dev.langchain4j.model.zhipu.ChatCompletionModel;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||
@@ -16,7 +14,6 @@ import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Service("ChatModelParameterConfig")
|
||||
@@ -100,14 +97,9 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
}
|
||||
|
||||
private static List<String> getCandidateValues() {
|
||||
List<String> candidateValues = getBaseUrlCandidateValues();
|
||||
candidateValues.add(AzureModelFactory.PROVIDER);
|
||||
return candidateValues;
|
||||
}
|
||||
|
||||
private static ArrayList<String> getBaseUrlCandidateValues() {
|
||||
return Lists.newArrayList(
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER,
|
||||
@@ -117,9 +109,10 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getBaseUrlCandidateValues(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||
@@ -152,24 +145,21 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
|
||||
OllamaModelFactory.PROVIDER, "qwen:0.5b",
|
||||
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
|
||||
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GLM_4.toString(),
|
||||
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
|
||||
AzureModelFactory.PROVIDER, "gpt-35-turbo",
|
||||
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(AzureModelFactory.PROVIDER, QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, "llama_2_70b"
|
||||
)
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,11 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
"ApiKey", "", "password",
|
||||
"向量模型配置", null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
|
||||
new Parameter("s2.embedding.model.secretKey", "demo",
|
||||
"SecretKey", "", "password",
|
||||
"向量模型配置", null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
"ModelName", "", "string",
|
||||
@@ -54,7 +59,8 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
|
||||
EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH
|
||||
EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH,
|
||||
EMBEDDING_MODEL_VOCABULARY_PATH
|
||||
);
|
||||
}
|
||||
|
||||
@@ -65,11 +71,12 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
|
||||
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
|
||||
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
|
||||
|
||||
String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY);
|
||||
return EmbeddingModelConfig.builder()
|
||||
.provider(provider)
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apiKey)
|
||||
.secretKey(secretKey)
|
||||
.modelName(modelName)
|
||||
.modelPath(modelPath)
|
||||
.vocabularyPath(vocabularyPath)
|
||||
@@ -135,12 +142,12 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
),
|
||||
ImmutableMap.of(
|
||||
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
OllamaModelFactory.PROVIDER, "all-minilm",
|
||||
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
|
||||
DashscopeModelFactory.PROVIDER, "text-embedding-v2",
|
||||
QianfanModelFactory.PROVIDER, "Embedding-V1",
|
||||
ZhipuModelFactory.PROVIDER, "embedding-2"
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -151,4 +158,11 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ public class EmbeddingModelConfig implements Serializable {
|
||||
private String provider;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String secretKey;
|
||||
private String modelName;
|
||||
private String modelPath;
|
||||
private String vocabularyPath;
|
||||
|
||||
Reference in New Issue
Block a user