From 39b5dde11d95572418fbdb4a1a329bb69e8fe2cf Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sat, 15 Jun 2024 01:37:40 +0800 Subject: [PATCH] (improvement)(launcher)Introduce Text2SQLEval to facilitate evaluation of different prompting strategies or different LLMs. #1152 --- .../supersonic/evaluation/Text2SQLEval.java | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java new file mode 100644 index 000000000..569f51ab4 --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -0,0 +1,150 @@ +package com.tencent.supersonic.evaluation; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.BaseTest; +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.agent.AgentConfig; +import com.tencent.supersonic.chat.server.agent.AgentToolType; +import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; +import com.tencent.supersonic.chat.server.agent.RuleParserTool; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.util.DataUtils; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Disabled +public class Text2SQLEval extends BaseTest { + + private int agentId; + + @BeforeAll + public void init() { + agentId = agentService.createAgent(getLLMAgent(false), DataUtils.getUser()); + } + + @Test + public void test_agg() throws Exception { + QueryResult result = submitNewChat("近30天访问次数", agentId); + assert result.getQueryColumns().size() == 1; + assert result.getQueryColumns().get(0).getName().contains("访问次数"); + } + + @Test + public void test_agg_and_groupby() throws Exception { + QueryResult result = submitNewChat("近30日每天的访问次数", agentId); + assert result.getQueryColumns().size() == 2; + assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date"); + assert result.getQueryColumns().get(1).getName().contains("访问次数"); + } + + @Test + public void test_drilldown() throws Exception { + QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId); + assert result.getQueryColumns().size() == 2; + assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门"); + assert result.getQueryColumns().get(1).getName().contains("访问次数"); + assert result.getQueryResults().size() == 4; + } + + @Test + public void test_drilldown_and_topN() throws Exception { + QueryResult result = submitNewChat("过去30天访问次数最高的部门top2", agentId); + assert result.getQueryColumns().size() == 2; + assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门"); + assert result.getQueryColumns().get(1).getName().contains("访问次数"); + assert result.getQueryResults().size() == 2; + } + + @Test + public void test_filter_and_top() throws Exception { + QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId); + assert result.getQueryColumns().size() == 2; + assert result.getQueryColumns().get(0).getName().contains("用户"); + assert result.getQueryColumns().get(1).getName().contains("访问次数"); + assert result.getQueryResults().size() == 1; + } + + @Test + public void test_filter() throws Exception { + QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId); + assert result.getQueryColumns().size() == 2; + assert result.getQueryColumns().get(0).getName().contains("用户"); + assert result.getQueryColumns().get(1).getName().contains("访问次数"); + assert result.getQueryResults().size() == 2; + } + + public static Agent getLLMAgent(boolean enableMultiturn) { + Agent agent = new Agent(); + agent.setName("Agent for Test"); + AgentConfig agentConfig = new AgentConfig(); + agentConfig.getTools().add(getLLMQueryTool()); + agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + agent.setLlmConfig(getLLMConfig(LLMType.GPT)); + MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); + multiTurnConfig.setEnableMultiTurn(enableMultiturn); + agent.setMultiTurnConfig(multiTurnConfig); + return agent; + } + + private static RuleParserTool getLLMQueryTool() { + RuleParserTool ruleQueryTool = new RuleParserTool(); + ruleQueryTool.setType(AgentToolType.NL2SQL_LLM); + ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L)); + + return ruleQueryTool; + } + + private enum LLMType { + GPT, + MOONSHOT, + DEEPSEEK, + QWEN, + GLM + } + + private static LLMConfig getLLMConfig(LLMType type) { + String baseUrl; + String apiKey; + String modelName; + double temperature = 0.0; + + switch (type) { + case GLM: + baseUrl = "https://open.bigmodel.cn/api/paas/v4/"; + apiKey = "REPLACE_WITH_YOUR_KEY"; + modelName = "glm-4"; + break; + case MOONSHOT: + baseUrl = "https://api.moonshot.cn/v1"; + apiKey = "REPLACE_WITH_YOUR_KEY"; + modelName = "moonshot-v1-8k"; + break; + case DEEPSEEK: + baseUrl = "https://api.deepseek.com"; + apiKey = "REPLACE_WITH_YOUR_KEY"; + modelName = "deepseek-coder"; + break; + case QWEN: + baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1"; + apiKey = "REPLACE_WITH_YOUR_KEY"; + modelName = "qwen-turbo"; + temperature = 0.01; + break; + case GPT: + default: + baseUrl = "https://api.openai.com/v1"; + apiKey = "REPLACE_WITH_YOUR_KEY"; + modelName = "gpt-3.5-turbo"; + temperature = 0.0; + break; + } + + return new LLMConfig("open_ai", + baseUrl, apiKey, modelName, temperature); + } +}