[improvement][headless-chat]Prompt LLM to generate with SQL statement to handle secondary calculation scenario.#1718

This commit is contained in:
jerryjzhang
2024-10-08 16:05:43 +08:00
parent 8bf5d395a7
commit 27e654a873
6 changed files with 37 additions and 48 deletions

View File

@@ -15,7 +15,7 @@
"question": "超音数过去90天美术部、技术研发部的访问时长",
"sideInfo": "CurrentDate=[2023-04-21]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'"
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-21' AND 数据日期 <= '2023-04-21'"
},
{
"question": "超音数访问时长小于1小时且来自美术部的用户是哪些",

View File

@@ -74,21 +74,19 @@ public class Text2SQLEval extends BaseTest {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
assert result.getQueryResults().size() == 3;
assert result.getTextResult().contains("marketing");
assert result.getTextResult().contains("sales");
assert result.getTextResult().contains("strategy");
}
@Test
public void test_filter_and_top() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().contains("用户");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
assert result.getQueryResults().size() == 1;
assert result.getTextResult().contains("tom");
}
@Test
@@ -96,31 +94,37 @@ public class Text2SQLEval extends BaseTest {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getQueryColumns().get(0).getName().contains("用户");
assert result.getQueryResults().size() == 2;
assert result.getTextResult().contains("alice");
assert result.getTextResult().contains("tom");
}
@Test
public void test_filter_compare() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更", agentId);
QueryResult result = submitNewChat("alice和lucy过去半个月的总停留时长更", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().contains("用户");
assert result.getQueryColumns().get(1).getName().contains("停留时长");
assert result.getQueryResults().size() >= 1;
assert result.getTextResult().contains("alice");
}
@Test
public void test_term() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
QueryResult result = submitNewChat("过去半个月每个核心用户的总停留时长", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getQueryColumns().stream().filter(c -> c.getName().contains("停留时长"))
.collect(Collectors.toList()).size() == 1;
assert result.getQueryResults().size() >= 1;
assert result.getQueryResults().size() == 2;
assert result.getTextResult().contains("tom");
assert result.getTextResult().contains("lucy");
}
@Test
public void test_second_calculation() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近1个月总访问次数超过100次的部门有几个", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 1;
assert result.getTextResult().contains("3");
}
public Agent getLLMAgent(boolean enableMultiturn) {

View File

@@ -59,7 +59,7 @@ public class LLMConfigUtils {
default:
baseUrl = "https://api.openai.com/v1";
apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "gpt-3.5-turbo";
modelName = "gpt-4o";
temperature = 0.0;
}

View File

@@ -15,7 +15,7 @@
"question": "超音数过去90天美术部、技术研发部的访问时长",
"sideInfo": "CurrentDate=[2023-04-21]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'"
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-21' AND 数据日期 <= '2023-04-21'"
},
{
"question": "超音数访问时长小于1小时且来自美术部的用户是哪些",