mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
[improvement][launcher]Add duration counter in Text2SQLEval.
This commit is contained in:
@@ -11,11 +11,9 @@ import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
|||||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import com.tencent.supersonic.util.LLMConfigUtils;
|
import com.tencent.supersonic.util.LLMConfigUtils;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.*;
|
||||||
import org.junit.jupiter.api.Disabled;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.api.TestInstance;
|
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||||
@@ -23,6 +21,7 @@ import java.util.stream.Collectors;
|
|||||||
public class Text2SQLEval extends BaseTest {
|
public class Text2SQLEval extends BaseTest {
|
||||||
|
|
||||||
private int agentId;
|
private int agentId;
|
||||||
|
private List<Long> durations = Lists.newArrayList();
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
public void init() {
|
public void init() {
|
||||||
@@ -30,16 +29,31 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
agentId = agent.getId();
|
agentId = agent.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@AfterAll
|
||||||
|
public void summarize() {
|
||||||
|
long total_duration = 0L;
|
||||||
|
for (Long duration : durations) {
|
||||||
|
total_duration += duration;
|
||||||
|
}
|
||||||
|
System.out.println(
|
||||||
|
String.format(
|
||||||
|
"Avg Duration: %d seconds", total_duration / 1000 / durations.size()));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_agg() throws Exception {
|
public void test_agg() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("近30天总访问次数", agentId);
|
QueryResult result = submitNewChat("近30天总访问次数", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 1;
|
assert result.getQueryColumns().size() == 1;
|
||||||
assert result.getQueryColumns().get(0).getName().contains("访问次数");
|
assert result.getQueryColumns().get(0).getName().contains("访问次数");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_agg_and_groupby() throws Exception {
|
public void test_agg_and_groupby() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
|
QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
assert result.getQueryColumns().size() == 2;
|
||||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
|
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
|
||||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||||
@@ -47,7 +61,9 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_drilldown() throws Exception {
|
public void test_drilldown() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
|
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
assert result.getQueryColumns().size() == 2;
|
||||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
||||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||||
@@ -56,7 +72,9 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_drilldown_and_topN() throws Exception {
|
public void test_drilldown_and_topN() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
|
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
assert result.getQueryColumns().size() == 2;
|
||||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
||||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||||
@@ -65,7 +83,9 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_filter_and_top() throws Exception {
|
public void test_filter_and_top() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
|
QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
assert result.getQueryColumns().size() == 2;
|
||||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||||
@@ -74,7 +94,9 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_filter() throws Exception {
|
public void test_filter() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
|
QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() >= 1;
|
assert result.getQueryColumns().size() >= 1;
|
||||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||||
assert result.getQueryResults().size() == 2;
|
assert result.getQueryResults().size() == 2;
|
||||||
@@ -82,7 +104,9 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_filter_compare() throws Exception {
|
public void test_filter_compare() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更高", agentId);
|
QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更高", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
assert result.getQueryColumns().size() == 2;
|
||||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||||
assert result.getQueryColumns().get(1).getName().contains("停留时长");
|
assert result.getQueryColumns().get(1).getName().contains("停留时长");
|
||||||
@@ -91,7 +115,9 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_term() throws Exception {
|
public void test_term() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() >= 1;
|
assert result.getQueryColumns().size() >= 1;
|
||||||
assert result.getQueryColumns().stream()
|
assert result.getQueryColumns().stream()
|
||||||
.filter(c -> c.getName().contains("停留时长"))
|
.filter(c -> c.getName().contains("停留时长"))
|
||||||
|
|||||||
Reference in New Issue
Block a user