(improvement)(chat) Add unit tests for each chatModel and embeddingModel. (#1582)

This commit is contained in:
lexluo09
2024-08-18 23:43:47 +08:00
committed by GitHub
parent 2801b27ade
commit 10a5e485cb
15 changed files with 13245 additions and 16198 deletions

View File

@@ -3,14 +3,14 @@ package com.tencent.supersonic.evaluation;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.util.DataUtils;
import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@@ -105,7 +105,7 @@ public class Text2SQLEval extends BaseTest {
AgentConfig agentConfig = new AgentConfig();
agentConfig.getTools().add(getLLMQueryTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setModelConfig(getLLMConfig(LLMType.GPT));
agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.GPT));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
agent.setMultiTurnConfig(multiTurnConfig);
@@ -119,58 +119,4 @@ public class Text2SQLEval extends BaseTest {
return ruleQueryTool;
}
private enum LLMType {
GPT,
MOONSHOT,
DEEPSEEK,
QWEN,
GLM
}
protected ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl;
String apiKey;
String modelName;
double temperature = 0.0;
switch (type) {
case GLM:
baseUrl = "https://open.bigmodel.cn/api/paas/v4/";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "glm-4";
break;
case MOONSHOT:
baseUrl = "https://api.moonshot.cn/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "moonshot-v1-8k";
break;
case DEEPSEEK:
baseUrl = "https://api.deepseek.com";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "deepseek-coder";
break;
case QWEN:
baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "qwen-turbo";
temperature = 0.01;
break;
case GPT:
default:
baseUrl = "https://api.openai.com/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "gpt-3.5-turbo";
temperature = 0.0;
}
ChatModelConfig chatModel = new ChatModelConfig();
chatModel.setModelName(modelName);
chatModel.setBaseUrl(baseUrl);
chatModel.setApiKey(apiKey);
chatModel.setTemperature(temperature);
chatModel.setProvider("open_ai");
return chatModel;
}
}

View File

@@ -0,0 +1,182 @@
package com.tencent.supersonic.provider;
import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.common.config.ParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.EmbeddingModelConstant;
import dev.langchain4j.provider.InMemoryModelFactory;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Disabled
public class ModelProviderTest extends BaseApplication {
@Test
public void test_openai_chat_model_with_openapi() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(OpenAiModelFactory.PROVIDER);
modelConfig.setModelName(OpenAiModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(OpenAiModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
String response = chatModel.generate("hi");
assertNotNull(response);
}
@Test
public void test_qianfan_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
modelConfig.setModelName(QianfanModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
modelConfig.setSecretKey(ParameterConfig.DEMO);
modelConfig.setApiKey(ParameterConfig.DEMO);
modelConfig.setEndpoint(QianfanModelFactory.DEFAULT_ENDPOINT);
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}
@Test
public void test_zhipu_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}
@Test
public void test_dashscope_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
modelConfig.setEnableSearch(true);
modelConfig.setApiKey(ParameterConfig.DEMO);
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}
@Test
public void test_azure_chat_model() {
ChatModelConfig modelConfig = new ChatModelConfig();
modelConfig.setProvider(AzureModelFactory.PROVIDER);
modelConfig.setModelName(AzureModelFactory.DEFAULT_MODEL_NAME);
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
}
@Test
public void test_in_memory_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(InMemoryModelFactory.PROVIDER);
modelConfig.setModelName(EmbeddingModelConstant.BGE_SMALL_ZH);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
Response<Embedding> embed = embeddingModel.embed("hi");
assertNotNull(embed);
}
@Test
public void test_openai_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(OpenAiModelFactory.PROVIDER);
modelConfig.setModelName(OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(OpenAiModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
Response<Embedding> embed = embeddingModel.embed("hi");
assertNotNull(embed);
}
@Test
public void test_azure_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(AzureModelFactory.PROVIDER);
modelConfig.setModelName(AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}
@Test
public void test_dashscope_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}
@Test
public void test_qianfan_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
modelConfig.setModelName(QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey(ParameterConfig.DEMO);
modelConfig.setSecretKey(ParameterConfig.DEMO);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}
@Test
public void test_zhipu_embedding_model() {
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
}
}

View File

@@ -0,0 +1,58 @@
package com.tencent.supersonic.util;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
public class LLMConfigUtils {
public enum LLMType {
GPT,
MOONSHOT,
DEEPSEEK,
QWEN,
GLM
}
public static ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl;
String apiKey;
String modelName;
double temperature = 0.0;
switch (type) {
case GLM:
baseUrl = "https://open.bigmodel.cn/api/pas/v4/";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "glm-4";
break;
case MOONSHOT:
baseUrl = "https://api.moonshot.cn/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "moonshot-v1-8k";
break;
case DEEPSEEK:
baseUrl = "https://api.deepseek.com";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "deepseek-coder";
break;
case QWEN:
baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "qwen-turbo";
temperature = 0.01;
break;
case GPT:
default:
baseUrl = "https://api.openai.com/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "gpt-3.5-turbo";
temperature = 0.0;
}
ChatModelConfig chatModel = new ChatModelConfig();
chatModel.setModelName(modelName);
chatModel.setBaseUrl(baseUrl);
chatModel.setApiKey(apiKey);
chatModel.setTemperature(temperature);
chatModel.setProvider("open_ai");
return chatModel;
}
}