Mirror of https://github.com/tencentmusic/supersonic.git (synced 2025-12-11 03:58:14 +00:00)
(improvement)(chat) Add unit tests for each chatModel and embeddingModel. (#1582)
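The hunks below touch three files: the existing Text2SQLEval evaluation test now takes its LLM connection settings from a new LLMConfigUtils helper, and a new ModelProviderTest exercises every supported chat and embedding provider. As a quick orientation, here is a minimal sketch (not part of the commit) of the pattern each provider test follows, using only names that appear in the diff; ParameterConfig.DEMO is a placeholder key, so a call against a real remote endpoint is expected to fail, which is exactly what most of the tests assert.

import com.tencent.supersonic.common.config.ParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.provider.OpenAiModelFactory;

// Hypothetical standalone sketch of the shared test shape: configure a provider,
// resolve a model through the ModelProvider facade, then make a single call.
public class ChatModelSmokeSketch {
    public static void main(String[] args) {
        ChatModelConfig config = new ChatModelConfig();
        config.setProvider(OpenAiModelFactory.PROVIDER);
        config.setModelName(OpenAiModelFactory.DEFAULT_MODEL_NAME);
        config.setBaseUrl(OpenAiModelFactory.DEFAULT_BASE_URL);
        config.setApiKey(ParameterConfig.DEMO); // placeholder; supply a real key to get an actual answer

        ChatLanguageModel chatModel = ModelProvider.getChatModel(config);
        System.out.println(chatModel.generate("hi")); // remote providers throw when the key is invalid
    }
}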
@@ -3,14 +3,14 @@ package com.tencent.supersonic.evaluation;
 import com.alibaba.fastjson.JSONObject;
 import com.google.common.collect.Lists;
 import com.tencent.supersonic.chat.BaseTest;
+import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
 import com.tencent.supersonic.chat.server.agent.Agent;
 import com.tencent.supersonic.chat.server.agent.AgentConfig;
 import com.tencent.supersonic.chat.server.agent.AgentToolType;
 import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
 import com.tencent.supersonic.chat.server.agent.RuleParserTool;
-import com.tencent.supersonic.common.pojo.ChatModelConfig;
-import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
 import com.tencent.supersonic.util.DataUtils;
+import com.tencent.supersonic.util.LLMConfigUtils;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
@@ -105,7 +105,7 @@ public class Text2SQLEval extends BaseTest {
         AgentConfig agentConfig = new AgentConfig();
         agentConfig.getTools().add(getLLMQueryTool());
         agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
-        agent.setModelConfig(getLLMConfig(LLMType.GPT));
+        agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.GPT));
         MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
         multiTurnConfig.setEnableMultiTurn(enableMultiturn);
         agent.setMultiTurnConfig(multiTurnConfig);
@@ -119,58 +119,4 @@ public class Text2SQLEval extends BaseTest {
 
         return ruleQueryTool;
     }
-
-    private enum LLMType {
-        GPT,
-        MOONSHOT,
-        DEEPSEEK,
-        QWEN,
-        GLM
-    }
-
-    protected ChatModelConfig getLLMConfig(LLMType type) {
-        String baseUrl;
-        String apiKey;
-        String modelName;
-        double temperature = 0.0;
-
-        switch (type) {
-            case GLM:
-                baseUrl = "https://open.bigmodel.cn/api/paas/v4/";
-                apiKey = "REPLACE_WITH_YOUR_KEY";
-                modelName = "glm-4";
-                break;
-            case MOONSHOT:
-                baseUrl = "https://api.moonshot.cn/v1";
-                apiKey = "REPLACE_WITH_YOUR_KEY";
-                modelName = "moonshot-v1-8k";
-                break;
-            case DEEPSEEK:
-                baseUrl = "https://api.deepseek.com";
-                apiKey = "REPLACE_WITH_YOUR_KEY";
-                modelName = "deepseek-coder";
-                break;
-            case QWEN:
-                baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
-                apiKey = "REPLACE_WITH_YOUR_KEY";
-                modelName = "qwen-turbo";
-                temperature = 0.01;
-                break;
-            case GPT:
-            default:
-                baseUrl = "https://api.openai.com/v1";
-                apiKey = "REPLACE_WITH_YOUR_KEY";
-                modelName = "gpt-3.5-turbo";
-                temperature = 0.0;
-        }
-        ChatModelConfig chatModel = new ChatModelConfig();
-        chatModel.setModelName(modelName);
-        chatModel.setBaseUrl(baseUrl);
-        chatModel.setApiKey(apiKey);
-        chatModel.setTemperature(temperature);
-        chatModel.setProvider("open_ai");
-
-        return chatModel;
-    }
-
 }
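In short, the private LLMType enum and getLLMConfig helper are lifted out of Text2SQLEval verbatim and reappear below as the shared LLMConfigUtils, so the evaluation test keeps the same behavior while any other test can reuse the same vendor presets.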
@@ -0,0 +1,182 @@
+package com.tencent.supersonic.provider;
+
+import com.tencent.supersonic.BaseApplication;
+import com.tencent.supersonic.common.config.ParameterConfig;
+import com.tencent.supersonic.common.pojo.ChatModelConfig;
+import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.provider.AzureModelFactory;
+import dev.langchain4j.provider.DashscopeModelFactory;
+import dev.langchain4j.provider.EmbeddingModelConstant;
+import dev.langchain4j.provider.InMemoryModelFactory;
+import dev.langchain4j.provider.ModelProvider;
+import dev.langchain4j.provider.OpenAiModelFactory;
+import dev.langchain4j.provider.QianfanModelFactory;
+import dev.langchain4j.provider.ZhipuModelFactory;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+import static org.junit.Assert.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+@Disabled
+public class ModelProviderTest extends BaseApplication {
+
+    @Test
+    public void test_openai_chat_model_with_openapi() {
+        ChatModelConfig modelConfig = new ChatModelConfig();
+        modelConfig.setProvider(OpenAiModelFactory.PROVIDER);
+        modelConfig.setModelName(OpenAiModelFactory.DEFAULT_MODEL_NAME);
+        modelConfig.setBaseUrl(OpenAiModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+
+        ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
+        String response = chatModel.generate("hi");
+        assertNotNull(response);
+    }
+
+    @Test
+    public void test_qianfan_chat_model() {
+        ChatModelConfig modelConfig = new ChatModelConfig();
+        modelConfig.setProvider(QianfanModelFactory.PROVIDER);
+        modelConfig.setModelName(QianfanModelFactory.DEFAULT_MODEL_NAME);
+        modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setSecretKey(ParameterConfig.DEMO);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+        modelConfig.setEndpoint(QianfanModelFactory.DEFAULT_ENDPOINT);
+
+        ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            chatModel.generate("hi");
+        });
+    }
+
+    @Test
+    public void test_zhipu_chat_model() {
+        ChatModelConfig modelConfig = new ChatModelConfig();
+        modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
+        modelConfig.setModelName(ZhipuModelFactory.DEFAULT_MODEL_NAME);
+        modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
+
+        ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            chatModel.generate("hi");
+        });
+    }
+
+    @Test
+    public void test_dashscope_chat_model() {
+        ChatModelConfig modelConfig = new ChatModelConfig();
+        modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
+        modelConfig.setModelName(DashscopeModelFactory.DEFAULT_MODEL_NAME);
+        modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setEnableSearch(true);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+
+        ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            chatModel.generate("hi");
+        });
+    }
+
+    @Test
+    public void test_azure_chat_model() {
+        ChatModelConfig modelConfig = new ChatModelConfig();
+        modelConfig.setProvider(AzureModelFactory.PROVIDER);
+        modelConfig.setModelName(AzureModelFactory.DEFAULT_MODEL_NAME);
+        modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+
+        ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            chatModel.generate("hi");
+        });
+    }
+
+    @Test
+    public void test_in_memory_embedding_model() {
+        EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
+        modelConfig.setProvider(InMemoryModelFactory.PROVIDER);
+        modelConfig.setModelName(EmbeddingModelConstant.BGE_SMALL_ZH);
+
+        EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
+        Response<Embedding> embed = embeddingModel.embed("hi");
+        assertNotNull(embed);
+    }
+
+    @Test
+    public void test_openai_embedding_model() {
+        EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
+        modelConfig.setProvider(OpenAiModelFactory.PROVIDER);
+        modelConfig.setModelName(OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
+        modelConfig.setBaseUrl(OpenAiModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+
+        EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
+        Response<Embedding> embed = embeddingModel.embed("hi");
+        assertNotNull(embed);
+    }
+
+    @Test
+    public void test_azure_embedding_model() {
+        EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
+        modelConfig.setProvider(AzureModelFactory.PROVIDER);
+        modelConfig.setModelName(AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
+        modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+
+        EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            embeddingModel.embed("hi");
+        });
+    }
+
+    @Test
+    public void test_dashscope_embedding_model() {
+        EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
+        modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
+        modelConfig.setModelName(DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
+        modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+
+        EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            embeddingModel.embed("hi");
+        });
+    }
+
+    @Test
+    public void test_qianfan_embedding_model() {
+        EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
+        modelConfig.setProvider(QianfanModelFactory.PROVIDER);
+        modelConfig.setModelName(QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
+        modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey(ParameterConfig.DEMO);
+        modelConfig.setSecretKey(ParameterConfig.DEMO);
+
+        EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            embeddingModel.embed("hi");
+        });
+    }
+
+    @Test
+    public void test_zhipu_embedding_model() {
+        EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
+        modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
+        modelConfig.setModelName(ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
+        modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
+        modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
+
+        EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
+        assertThrows(RuntimeException.class, () -> {
+            embeddingModel.embed("hi");
+        });
+    }
+}
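Every remote-provider test above asserts that a call made with a placeholder or throwaway key throws, so the in-memory embedding model is the one path that can be exercised fully offline. Below is a minimal sketch of that path (not part of the commit; Embedding.dimension() is assumed from langchain4j's API, and BGE-small models typically produce 512-dimensional vectors).

import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.EmbeddingModelConstant;
import dev.langchain4j.provider.InMemoryModelFactory;
import dev.langchain4j.provider.ModelProvider;

// Hypothetical offline check: the in-memory provider needs no API key.
public class InMemoryEmbeddingSketch {
    public static void main(String[] args) {
        EmbeddingModelConfig config = new EmbeddingModelConfig();
        config.setProvider(InMemoryModelFactory.PROVIDER);
        config.setModelName(EmbeddingModelConstant.BGE_SMALL_ZH);

        EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(config);
        Response<Embedding> response = embeddingModel.embed("hi");
        // Print the vector size to confirm the local model loaded and produced an embedding.
        System.out.println("dimension: " + response.content().dimension());
    }
}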
@@ -0,0 +1,58 @@
+package com.tencent.supersonic.util;
+
+import com.tencent.supersonic.common.pojo.ChatModelConfig;
+
+public class LLMConfigUtils {
+    public enum LLMType {
+        GPT,
+        MOONSHOT,
+        DEEPSEEK,
+        QWEN,
+        GLM
+    }
+
+    public static ChatModelConfig getLLMConfig(LLMType type) {
+        String baseUrl;
+        String apiKey;
+        String modelName;
+        double temperature = 0.0;
+
+        switch (type) {
+            case GLM:
+                baseUrl = "https://open.bigmodel.cn/api/paas/v4/";
+                apiKey = "REPLACE_WITH_YOUR_KEY";
+                modelName = "glm-4";
+                break;
+            case MOONSHOT:
+                baseUrl = "https://api.moonshot.cn/v1";
+                apiKey = "REPLACE_WITH_YOUR_KEY";
+                modelName = "moonshot-v1-8k";
+                break;
+            case DEEPSEEK:
+                baseUrl = "https://api.deepseek.com";
+                apiKey = "REPLACE_WITH_YOUR_KEY";
+                modelName = "deepseek-coder";
+                break;
+            case QWEN:
+                baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
+                apiKey = "REPLACE_WITH_YOUR_KEY";
+                modelName = "qwen-turbo";
+                temperature = 0.01;
+                break;
+            case GPT:
+            default:
+                baseUrl = "https://api.openai.com/v1";
+                apiKey = "REPLACE_WITH_YOUR_KEY";
+                modelName = "gpt-3.5-turbo";
+                temperature = 0.0;
+        }
+        ChatModelConfig chatModel = new ChatModelConfig();
+        chatModel.setModelName(modelName);
+        chatModel.setBaseUrl(baseUrl);
+        chatModel.setApiKey(apiKey);
+        chatModel.setTemperature(temperature);
+        chatModel.setProvider("open_ai");
+
+        return chatModel;
+    }
+}
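With this helper in place, Text2SQLEval selects a vendor in a single line, and the same ChatModelConfig can also be resolved into a langchain4j model directly. A hedged usage sketch (not part of the commit) follows; every preset sets the provider to "open_ai", presumably because all five baseUrls expose OpenAI-compatible endpoints, and the REPLACE_WITH_YOUR_KEY placeholder must be replaced before any call can succeed.

import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.util.LLMConfigUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider;

// Hypothetical example: pick a vendor preset and turn it into a chat model.
public class EvalModelSelectionSketch {
    public static void main(String[] args) {
        // QWEN is just an example; GPT, MOONSHOT, DEEPSEEK and GLM work the same way.
        ChatModelConfig config = LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.QWEN);
        // The config can be attached to an Agent (as Text2SQLEval does)
        // or handed straight to the ModelProvider facade:
        ChatLanguageModel model = ModelProvider.getChatModel(config);
        System.out.println(model.generate("hi")); // needs a real API key in place of the placeholder
    }
}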