From 6a8686f51318dad73dd3b821fd8d656674b3fac8 Mon Sep 17 00:00:00 2001
From: jerryjzhang
Date: Fri, 20 Sep 2024 17:24:00 +0800
Subject: [PATCH] [improvement][launcher]Add duration counter in `Text2SQLEval`.

---
Review notes (this free-form section is ignored when the patch is applied):
  * `durations` is declared as `List<Long>`; a raw `List` cannot be iterated
    with `for (Long duration : durations)`.
  * `summarize()` no longer divides by zero when no query was timed.
  * Only successful `submitNewChat` calls are timed; a call that throws
    records no duration.
  * Hunk line numbers and the diffstat were recomputed for the added lines;
    the blob hash on the `index` line below is unchanged from the original
    patch and therefore no longer matches the post-image.

 .../supersonic/evaluation/Text2SQLEval.java   | 42 ++++++++++++++++---
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java
index ea36bd86b..035835b8a 100644
--- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java
+++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java
@@ -11,11 +11,9 @@ import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
 import com.tencent.supersonic.chat.server.agent.RuleParserTool;
 import com.tencent.supersonic.util.DataUtils;
 import com.tencent.supersonic.util.LLMConfigUtils;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.api.*;
 
+import java.util.List;
 import java.util.stream.Collectors;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -23,6 +21,7 @@ import java.util.stream.Collectors;
 public class Text2SQLEval extends BaseTest {
 
     private int agentId;
+    private final List<Long> durations = Lists.newArrayList();
 
     @BeforeAll
     public void init() {
@@ -30,16 +29,39 @@ public class Text2SQLEval extends BaseTest {
         agentId = agent.getId();
     }
 
+    /**
+     * Prints the average duration, in whole seconds, of the timed {@code submitNewChat} calls.
+     */
+    @AfterAll
+    public void summarize() {
+        if (durations.isEmpty()) {
+            // Nothing was timed (e.g. every query threw); avoid dividing by zero below.
+            System.out.println("Avg Duration: n/a (no queries were timed)");
+            return;
+        }
+        long totalDuration = 0L;
+        for (long duration : durations) {
+            totalDuration += duration;
+        }
+        System.out.println(
+                String.format(
+                        "Avg Duration: %d seconds", totalDuration / 1000 / durations.size()));
+    }
+
     @Test
     public void test_agg() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("近30天总访问次数", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() == 1;
         assert result.getQueryColumns().get(0).getName().contains("访问次数");
     }
 
     @Test
     public void test_agg_and_groupby() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() == 2;
         assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
         assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -47,7 +69,9 @@ public class Text2SQLEval extends BaseTest {
 
     @Test
     public void test_drilldown() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() == 2;
         assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
         assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -56,7 +80,9 @@ public class Text2SQLEval extends BaseTest {
 
     @Test
     public void test_drilldown_and_topN() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() == 2;
         assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
         assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -65,7 +91,9 @@ public class Text2SQLEval extends BaseTest {
 
     @Test
     public void test_filter_and_top() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() == 2;
         assert result.getQueryColumns().get(0).getName().contains("用户");
         assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -74,7 +102,9 @@ public class Text2SQLEval extends BaseTest {
 
     @Test
     public void test_filter() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() >= 1;
         assert result.getQueryColumns().get(0).getName().contains("用户");
         assert result.getQueryResults().size() == 2;
@@ -82,7 +112,9 @@ public class Text2SQLEval extends BaseTest {
 
     @Test
     public void test_filter_compare() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更高", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() == 2;
         assert result.getQueryColumns().get(0).getName().contains("用户");
         assert result.getQueryColumns().get(1).getName().contains("停留时长");
@@ -91,7 +123,9 @@ public class Text2SQLEval extends BaseTest {
 
     @Test
     public void test_term() throws Exception {
+        long start = System.currentTimeMillis();
         QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
+        durations.add(System.currentTimeMillis() - start);
         assert result.getQueryColumns().size() >= 1;
         assert result.getQueryColumns().stream()
                 .filter(c -> c.getName().contains("停留时长"))