Mirror of https://github.com/tencentmusic/supersonic.git (synced 2025-12-10 19:51:00 +00:00)
(improvement)(chat) Add unit tests for each chatModel and embeddingModel. (#1582)
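Every test added in this commit follows the same shape: build a config from a factory's DEFAULT_* constants plus a placeholder key, resolve a model through ModelProvider, then assert either a non-null response or a RuntimeException. A minimal sketch of that pattern, condensed from the Azure chat test added below (the class name here is illustrative; the real ModelProviderTest additionally extends BaseApplication and is @Disabled by default):

import com.tencent.supersonic.common.config.ParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.ModelProvider;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;

public class ChatModelSmokeTestSketch {

    @Test
    public void azure_chat_model_with_demo_key_fails() {
        // Config built entirely from the factory's defaults plus the DEMO placeholder key.
        ChatModelConfig modelConfig = new ChatModelConfig();
        modelConfig.setProvider(AzureModelFactory.PROVIDER);
        modelConfig.setModelName(AzureModelFactory.DEFAULT_MODEL_NAME);
        modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
        modelConfig.setApiKey(ParameterConfig.DEMO);

        // ModelProvider resolves the configured provider name to the matching factory.
        ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);

        // With a placeholder key the remote call is expected to fail.
        assertThrows(RuntimeException.class, () -> chatModel.generate("hi"));
    }
}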
@@ -4,8 +4,6 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.model.zhipu.ChatCompletionModel;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory;
@@ -16,7 +14,6 @@ import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

@Service("ChatModelParameterConfig")
@@ -100,14 +97,9 @@ public class ChatModelParameterConfig extends ParameterConfig {
}

private static List<String> getCandidateValues() {
List<String> candidateValues = getBaseUrlCandidateValues();
candidateValues.add(AzureModelFactory.PROVIDER);
return candidateValues;
}

private static ArrayList<String> getBaseUrlCandidateValues() {
return Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
@@ -117,9 +109,10 @@ public class ChatModelParameterConfig extends ParameterConfig {

private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
getBaseUrlCandidateValues(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
@@ -152,24 +145,21 @@ public class ChatModelParameterConfig extends ParameterConfig {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
OllamaModelFactory.PROVIDER, "qwen:0.5b",
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GLM_4.toString(),
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
AzureModelFactory.PROVIDER, "gpt-35-turbo",
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME
)
);
}

private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(AzureModelFactory.PROVIDER, QianfanModelFactory.PROVIDER),
ImmutableMap.of(
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, "llama_2_70b"
)
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT)
);
}

@@ -36,6 +36,11 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
"ApiKey", "", "password",
"向量模型配置", null, getApiKeyDependency());

public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
new Parameter("s2.embedding.model.secretKey", "demo",
"SecretKey", "", "password",
"向量模型配置", null, getSecretKeyDependency());

public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName", "", "string",
@@ -54,7 +59,8 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH
EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH,
EMBEDDING_MODEL_VOCABULARY_PATH
);
}

@@ -65,11 +71,12 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);

String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY);
return EmbeddingModelConfig.builder()
.provider(provider)
.baseUrl(baseUrl)
.apiKey(apiKey)
.secretKey(secretKey)
.modelName(modelName)
.modelPath(modelPath)
.vocabularyPath(vocabularyPath)
@@ -135,12 +142,12 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
OllamaModelFactory.PROVIDER, "all-minilm",
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
DashscopeModelFactory.PROVIDER, "text-embedding-v2",
QianfanModelFactory.PROVIDER, "Embedding-V1",
ZhipuModelFactory.PROVIDER, "embedding-2"
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME
)
);
}
@@ -151,4 +158,11 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
);
}

private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
);
}
}

@@ -16,6 +16,7 @@ public class EmbeddingModelConfig implements Serializable {
private String provider;
private String baseUrl;
private String apiKey;
private String secretKey;
private String modelName;
private String modelPath;
private String vocabularyPath;

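The secretKey added here flows from EmbeddingModelParameterConfig.convert() above into the provider factories below. A small usage sketch, assuming the Lombok builder shown in convert() and the Qianfan defaults defined further down in this diff (the wrapper class name and credential strings are placeholders):

import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.provider.QianfanModelFactory;

public class QianfanEmbeddingConfigExample {
    public static EmbeddingModel build() {
        // Mirrors what convert() does, but with the Qianfan factory defaults filled in directly.
        EmbeddingModelConfig config = EmbeddingModelConfig.builder()
                .provider(QianfanModelFactory.PROVIDER)
                .baseUrl(QianfanModelFactory.DEFAULT_BASE_URL)
                .apiKey("your-api-key")        // placeholder
                .secretKey("your-secret-key")  // placeholder; field added by this commit
                .modelName(QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME)
                .build();
        // Resolved the same way the new embedding tests do.
        return ModelProvider.getEmbeddingModel(config);
    }
}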
@@ -14,12 +14,14 @@ import java.time.Duration;
@Service
public class AzureModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "AZURE";
public static final String DEFAULT_BASE_URL = "https://xxxx.openai.azure.com/";
public static final String DEFAULT_BASE_URL = "https://your-resource-name.openai.azure.com/";
public static final String DEFAULT_MODEL_NAME = "gpt-35-turbo";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";

@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getEndpoint())
.endpoint(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())

@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -13,6 +14,8 @@ import org.springframework.stereotype.Service;
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
public static final String DEFAULT_MODEL_NAME = QwenModelName.QWEN_PLUS;
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-v2";

@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {

@@ -15,6 +15,7 @@ import java.time.Duration;
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "LOCAL_AI";
public static final String DEFAULT_BASE_URL = "http://localhost:8080";
public static final String DEFAULT_MODEL_NAME = "ggml-gpt4all-j";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel

@@ -13,6 +13,7 @@ import java.util.HashMap;
import java.util.Map;

public class ModelProvider {

private static final Map<String, ModelFactory> factories = new HashMap<>();

public static void add(String provider, ModelFactory modelFactory) {

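ModelProvider, shown just above, keeps a static map from provider name to ModelFactory, and the new ModelProviderTest resolves models through ModelProvider.getChatModel and getEmbeddingModel. The diff does not show how factories get into that map; a hedged sketch of the likely wiring, given that every factory in this commit implements InitializingBean (the nested interface and method bodies below are assumptions, only add(...) and the factories map appear in the diff):

import com.tencent.supersonic.common.pojo.ChatModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.HashMap;
import java.util.Map;

class ModelProviderSketch {

    // Stand-in for the real ModelFactory interface, which the diff references but does not show.
    interface ModelFactorySketch {
        ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
    }

    private static final Map<String, ModelFactorySketch> factories = new HashMap<>();

    // Each factory (OpenAi, Ollama, Qianfan, ...) implements InitializingBean, so it can
    // register itself once Spring has initialized it, roughly:
    //   @Override
    //   public void afterPropertiesSet() {
    //       ModelProvider.add(PROVIDER, this);
    //   }
    static void add(String provider, ModelFactorySketch factory) {
        factories.put(provider, factory);
    }

    // What ModelProvider.getChatModel(config) presumably does: pick the factory registered
    // under the configured provider name and delegate model creation to it.
    static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
        return factories.get(modelConfig.getProvider()).createChatModel(modelConfig);
    }
}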
@@ -16,6 +16,8 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {

public static final String PROVIDER = "OLLAMA";
public static final String DEFAULT_BASE_URL = "http://localhost:11434";
public static final String DEFAULT_MODEL_NAME = "qwen:0.5b";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "all-minilm";

@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {

@@ -16,6 +16,8 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {

public static final String PROVIDER = "OPEN_AI";
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
public static final String DEFAULT_MODEL_NAME = "gpt-3.5-turbo";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";

@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {

@@ -14,6 +14,10 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {

public static final String PROVIDER = "QIANFAN";
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
public static final String DEFAULT_MODEL_NAME = "Llama-2-70b-chat";

public static final String DEFAULT_EMBEDDING_MODEL_NAME = "Embedding-V1";
public static final String DEFAULT_ENDPOINT = "llama_2_70b";

@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -36,6 +40,7 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
return QianfanEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.secretKey(embeddingModelConfig.getSecretKey())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ChatCompletionModel;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
@@ -13,7 +14,8 @@ import org.springframework.stereotype.Service;
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU";
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";

public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return ZhipuAiChatModel.builder()

@@ -3,14 +3,14 @@ package com.tencent.supersonic.evaluation;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.util.DataUtils;
import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@@ -105,7 +105,7 @@ public class Text2SQLEval extends BaseTest {
AgentConfig agentConfig = new AgentConfig();
agentConfig.getTools().add(getLLMQueryTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setModelConfig(getLLMConfig(LLMType.GPT));
agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.GPT));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
agent.setMultiTurnConfig(multiTurnConfig);
@@ -119,58 +119,4 @@ public class Text2SQLEval extends BaseTest {

return ruleQueryTool;
}

private enum LLMType {
GPT,
MOONSHOT,
DEEPSEEK,
QWEN,
GLM
}

protected ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl;
String apiKey;
String modelName;
double temperature = 0.0;

switch (type) {
case GLM:
baseUrl = "https://open.bigmodel.cn/api/paas/v4/";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "glm-4";
break;
case MOONSHOT:
baseUrl = "https://api.moonshot.cn/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "moonshot-v1-8k";
break;
case DEEPSEEK:
baseUrl = "https://api.deepseek.com";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "deepseek-coder";
break;
case QWEN:
baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "qwen-turbo";
temperature = 0.01;
break;
case GPT:
default:
baseUrl = "https://api.openai.com/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "gpt-3.5-turbo";
temperature = 0.0;
}
ChatModelConfig chatModel = new ChatModelConfig();
chatModel.setModelName(modelName);
chatModel.setBaseUrl(baseUrl);
chatModel.setApiKey(apiKey);
chatModel.setTemperature(temperature);
chatModel.setProvider("open_ai");

return chatModel;
}

}

@@ -0,0 +1,182 @@
package com.tencent.supersonic.provider;

import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.common.config.ParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.EmbeddingModelConstant;
import dev.langchain4j.provider.InMemoryModelFactory;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Disabled
public class ModelProviderTest extends BaseApplication {

@Test
public void test_openai_chat_model_with_openapi() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(OpenAiModelFactory.PROVIDER);
modelConfig.setModelName(OpenAiModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(OpenAiModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);

ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
String response = chatModel.generate("hi");
assertNotNull(response);
}

@Test
public void test_qianfan_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
modelConfig.setModelName(QianfanModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
modelConfig.setSecretKey(ParameterConfig.DEMO);
modelConfig.setApiKey(ParameterConfig.DEMO);
modelConfig.setEndpoint(QianfanModelFactory.DEFAULT_ENDPOINT);

ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}

@Test
public void test_zhipu_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");

ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}

@Test
public void test_dashscope_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
modelConfig.setEnableSearch(true);
modelConfig.setApiKey(ParameterConfig.DEMO);

ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}

@Test
public void test_azure_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(AzureModelFactory.PROVIDER);
modelConfig.setModelName(AzureModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);

ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}

@Test
public void test_in_memory_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(InMemoryModelFactory.PROVIDER);
modelConfig.setModelName(EmbeddingModelConstant.BGE_SMALL_ZH);

EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
Response<Embedding> embed = embeddingModel.embed("hi");
assertNotNull(embed);
}

@Test
public void test_openai_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(OpenAiModelFactory.PROVIDER);
modelConfig.setModelName(OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(OpenAiModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);

EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
Response<Embedding> embed = embeddingModel.embed("hi");
assertNotNull(embed);
}

@Test
public void test_azure_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(AzureModelFactory.PROVIDER);
modelConfig.setModelName(AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);

EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}

@Test
public void test_dashscope_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);

EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}

@Test
public void test_qianfan_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
modelConfig.setModelName(QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);
modelConfig.setSecretKey(ParameterConfig.DEMO);

EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}

@Test
public void test_zhipu_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");

EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}
}
@@ -0,0 +1,58 @@
package com.tencent.supersonic.util;

import com.tencent.supersonic.common.pojo.ChatModelConfig;

public class LLMConfigUtils {
public enum LLMType {
GPT,
MOONSHOT,
DEEPSEEK,
QWEN,
GLM
}

public static ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl;
String apiKey;
String modelName;
double temperature = 0.0;

switch (type) {
case GLM:
baseUrl = "https://open.bigmodel.cn/api/pas/v4/";
|
||||
apiKey = "REPLACE_WITH_YOUR_KEY";
|
||||
modelName = "glm-4";
|
||||
break;
|
||||
case MOONSHOT:
|
||||
baseUrl = "https://api.moonshot.cn/v1";
|
||||
apiKey = "REPLACE_WITH_YOUR_KEY";
|
||||
modelName = "moonshot-v1-8k";
|
||||
break;
|
||||
case DEEPSEEK:
|
||||
baseUrl = "https://api.deepseek.com";
|
||||
apiKey = "REPLACE_WITH_YOUR_KEY";
|
||||
modelName = "deepseek-coder";
|
||||
break;
|
||||
case QWEN:
|
||||
baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
|
||||
apiKey = "REPLACE_WITH_YOUR_KEY";
|
||||
modelName = "qwen-turbo";
|
||||
temperature = 0.01;
|
||||
break;
|
||||
case GPT:
|
||||
default:
|
||||
baseUrl = "https://api.openai.com/v1";
|
||||
apiKey = "REPLACE_WITH_YOUR_KEY";
|
||||
modelName = "gpt-3.5-turbo";
|
||||
temperature = 0.0;
|
||||
}
|
||||
ChatModelConfig chatModel = new ChatModelConfig();
|
||||
chatModel.setModelName(modelName);
|
||||
chatModel.setBaseUrl(baseUrl);
|
||||
chatModel.setApiKey(apiKey);
|
||||
chatModel.setTemperature(temperature);
|
||||
chatModel.setProvider("open_ai");
|
||||
|
||||
return chatModel;
|
||||
}
|
||||
}
|
||||
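Usage note: Text2SQLEval above now delegates to this shared helper instead of its own private copy, and other evaluation tests can do the same. A minimal consumer sketch (the wrapper class is illustrative, the GLM variant is an arbitrary choice, and the placeholder key inside LLMConfigUtils must be replaced before running):

import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.util.LLMConfigUtils;

public class LLMConfigUtilsUsageExample {
    // Attaches a GLM chat model config to an agent, mirroring the Text2SQLEval change above.
    public static Agent withGlm(Agent agent) {
        ChatModelConfig config = LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.GLM);
        agent.setModelConfig(config);
        return agent;
    }
}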
webapp/pnpm-lock.yaml (generated, 29054 changed lines): file diff suppressed because it is too large.