mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][launcher]Add duration counter in Text2SQLEval.
This commit is contained in:
@@ -11,11 +11,9 @@ import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import com.tencent.supersonic.util.LLMConfigUtils;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
import org.junit.jupiter.api.*;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
@@ -23,6 +21,7 @@ import java.util.stream.Collectors;
|
||||
public class Text2SQLEval extends BaseTest {
|
||||
|
||||
private int agentId;
|
||||
private List<Long> durations = Lists.newArrayList();
|
||||
|
||||
@BeforeAll
|
||||
public void init() {
|
||||
@@ -30,16 +29,31 @@ public class Text2SQLEval extends BaseTest {
|
||||
agentId = agent.getId();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public void summarize() {
|
||||
long total_duration = 0L;
|
||||
for (Long duration : durations) {
|
||||
total_duration += duration;
|
||||
}
|
||||
System.out.println(
|
||||
String.format(
|
||||
"Avg Duration: %d seconds", total_duration / 1000 / durations.size()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_agg() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("近30天总访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 1;
|
||||
assert result.getQueryColumns().get(0).getName().contains("访问次数");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_agg_and_groupby() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
@@ -47,7 +61,9 @@ public class Text2SQLEval extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_drilldown() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
@@ -56,7 +72,9 @@ public class Text2SQLEval extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_drilldown_and_topN() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
@@ -65,7 +83,9 @@ public class Text2SQLEval extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_filter_and_top() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
@@ -74,7 +94,9 @@ public class Text2SQLEval extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_filter() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||
assert result.getQueryResults().size() == 2;
|
||||
@@ -82,7 +104,9 @@ public class Text2SQLEval extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_filter_compare() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更高", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||
assert result.getQueryColumns().get(1).getName().contains("停留时长");
|
||||
@@ -91,7 +115,9 @@ public class Text2SQLEval extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_term() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getQueryColumns().stream()
|
||||
.filter(c -> c.getName().contains("停留时长"))
|
||||
|
||||
Reference in New Issue
Block a user