(improvement)(chat) Add unit tests for each chatModel and embeddingModel. (#1582)

This commit is contained in:
lexluo09
2024-08-18 23:43:47 +08:00
committed by GitHub
parent 2801b27ade
commit 10a5e485cb
15 changed files with 13245 additions and 16198 deletions

View File

@@ -4,8 +4,6 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.model.zhipu.ChatCompletionModel;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory;
@@ -16,7 +14,6 @@ import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
@Service("ChatModelParameterConfig")
@@ -100,14 +97,9 @@ public class ChatModelParameterConfig extends ParameterConfig {
}
private static List<String> getCandidateValues() {
List<String> candidateValues = getBaseUrlCandidateValues();
candidateValues.add(AzureModelFactory.PROVIDER);
return candidateValues;
}
private static ArrayList<String> getBaseUrlCandidateValues() {
return Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
@@ -117,9 +109,10 @@ public class ChatModelParameterConfig extends ParameterConfig {
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
getBaseUrlCandidateValues(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
@@ -152,24 +145,21 @@ public class ChatModelParameterConfig extends ParameterConfig {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
OllamaModelFactory.PROVIDER, "qwen:0.5b",
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GLM_4.toString(),
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
AzureModelFactory.PROVIDER, "gpt-35-turbo",
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME
)
);
}
private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(AzureModelFactory.PROVIDER, QianfanModelFactory.PROVIDER),
ImmutableMap.of(
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, "llama_2_70b"
)
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT)
);
}

View File

@@ -36,6 +36,11 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
"ApiKey", "", "password",
"向量模型配置", null, getApiKeyDependency());
public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
new Parameter("s2.embedding.model.secretKey", "demo",
"SecretKey", "", "password",
"向量模型配置", null, getSecretKeyDependency());
public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName", "", "string",
@@ -54,7 +59,8 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH
EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH,
EMBEDDING_MODEL_VOCABULARY_PATH
);
}
@@ -65,11 +71,12 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY);
return EmbeddingModelConfig.builder()
.provider(provider)
.baseUrl(baseUrl)
.apiKey(apiKey)
.secretKey(secretKey)
.modelName(modelName)
.modelPath(modelPath)
.vocabularyPath(vocabularyPath)
@@ -135,12 +142,12 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
OllamaModelFactory.PROVIDER, "all-minilm",
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
DashscopeModelFactory.PROVIDER, "text-embedding-v2",
QianfanModelFactory.PROVIDER, "Embedding-V1",
ZhipuModelFactory.PROVIDER, "embedding-2"
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME
)
);
}
@@ -151,4 +158,11 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
);
}
private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
);
}
}

View File

@@ -16,6 +16,7 @@ public class EmbeddingModelConfig implements Serializable {
private String provider;
private String baseUrl;
private String apiKey;
private String secretKey;
private String modelName;
private String modelPath;
private String vocabularyPath;

View File

@@ -14,12 +14,14 @@ import java.time.Duration;
@Service
public class AzureModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "AZURE";
public static final String DEFAULT_BASE_URL = "https://xxxx.openai.azure.com/";
public static final String DEFAULT_BASE_URL = "https://your-resource-name.openai.azure.com/";
public static final String DEFAULT_MODEL_NAME = "gpt-35-turbo";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getEndpoint())
.endpoint(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -13,6 +14,8 @@ import org.springframework.stereotype.Service;
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
public static final String DEFAULT_MODEL_NAME = QwenModelName.QWEN_PLUS;
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-v2";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {

View File

@@ -15,6 +15,7 @@ import java.time.Duration;
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "LOCAL_AI";
public static final String DEFAULT_BASE_URL = "http://localhost:8080";
public static final String DEFAULT_MODEL_NAME = "ggml-gpt4all-j";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel

View File

@@ -13,6 +13,7 @@ import java.util.HashMap;
import java.util.Map;
public class ModelProvider {
private static final Map<String, ModelFactory> factories = new HashMap<>();
public static void add(String provider, ModelFactory modelFactory) {

View File

@@ -16,6 +16,8 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OLLAMA";
public static final String DEFAULT_BASE_URL = "http://localhost:11434";
public static final String DEFAULT_MODEL_NAME = "qwen:0.5b";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "all-minilm";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {

View File

@@ -16,6 +16,8 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OPEN_AI";
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
public static final String DEFAULT_MODEL_NAME = "gpt-3.5-turbo";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {

View File

@@ -14,6 +14,10 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "QIANFAN";
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
public static final String DEFAULT_MODEL_NAME = "Llama-2-70b-chat";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "Embedding-V1";
public static final String DEFAULT_ENDPOINT = "llama_2_70b";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -36,6 +40,7 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
return QianfanEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.secretKey(embeddingModelConfig.getSecretKey())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ChatCompletionModel;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
@@ -13,7 +14,8 @@ import org.springframework.stereotype.Service;
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU";
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return ZhipuAiChatModel.builder()