[improvement][launcher]Add duration counter in Text2SQLEval.

This commit is contained in:
jerryjzhang
2024-09-20 17:24:00 +08:00
parent c045b34328
commit 6a8686f513

View File

@@ -11,11 +11,9 @@ import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.util.DataUtils;
import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.*;
import java.util.List;
import java.util.stream.Collectors;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -23,6 +21,7 @@ import java.util.stream.Collectors;
public class Text2SQLEval extends BaseTest {
private int agentId;
private List<Long> durations = Lists.newArrayList();
@BeforeAll
public void init() {
@@ -30,16 +29,31 @@ public class Text2SQLEval extends BaseTest {
agentId = agent.getId();
}
@AfterAll
public void summarize() {
long total_duration = 0L;
for (Long duration : durations) {
total_duration += duration;
}
System.out.println(
String.format(
"Avg Duration: %d seconds", total_duration / 1000 / durations.size()));
}
@Test
public void test_agg() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近30天总访问次数", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 1;
assert result.getQueryColumns().get(0).getName().contains("访问次数");
}
@Test
public void test_agg_and_groupby() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -47,7 +61,9 @@ public class Text2SQLEval extends BaseTest {
@Test
public void test_drilldown() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -56,7 +72,9 @@ public class Text2SQLEval extends BaseTest {
@Test
public void test_drilldown_and_topN() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -65,7 +83,9 @@ public class Text2SQLEval extends BaseTest {
@Test
public void test_filter_and_top() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().contains("用户");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
@@ -74,7 +94,9 @@ public class Text2SQLEval extends BaseTest {
@Test
public void test_filter() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getQueryColumns().get(0).getName().contains("用户");
assert result.getQueryResults().size() == 2;
@@ -82,7 +104,9 @@ public class Text2SQLEval extends BaseTest {
@Test
public void test_filter_compare() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更高", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().contains("用户");
assert result.getQueryColumns().get(1).getName().contains("停留时长");
@@ -91,7 +115,9 @@ public class Text2SQLEval extends BaseTest {
@Test
public void test_term() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getQueryColumns().stream()
.filter(c -> c.getName().contains("停留时长"))