mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
[improvement][headless-chat] Prompt the LLM to generate a SQL `with` statement to handle secondary calculation scenarios. #1718
This commit is contained in:
@@ -34,9 +34,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
+ "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
+ "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||||
+ "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
+ "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||||
+ "\n4.DO NOT calculate date range using functions."
|
+ "\n4.DO NOT calculate date range using functions."
|
||||||
+ "\n5.DO NOT calculate date range using DATE_SUB."
|
+ "\n5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||||
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
+ "\n6.ALWAYS use `with` statement if nested aggregation is needed."
|
||||||
+ "\n7.ALWAYS USE `with` statement to handle secondary calculation scenario.\""
|
|
||||||
+ "\n#Exemplars:\n{{exemplar}}"
|
+ "\n#Exemplars:\n{{exemplar}}"
|
||||||
+ "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
+ "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||||
|
|
||||||
|
|||||||
@@ -274,17 +274,11 @@ public class SqlGenerateUtils {
|
|||||||
return modelBizName + UNDERLINE + executorConfig.getInternalMetricNameSuffix();
|
return modelBizName + UNDERLINE + executorConfig.getInternalMetricNameSuffix();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String generateDerivedMetric(
|
public String generateDerivedMetric(final List<MetricSchemaResp> metricResps,
|
||||||
final List<MetricSchemaResp> metricResps,
|
final Set<String> allFields, final Map<String, Measure> allMeasures,
|
||||||
final Set<String> allFields,
|
final List<DimSchemaResp> dimensionResps, final String expression,
|
||||||
final Map<String, Measure> allMeasures,
|
final MetricDefineType metricDefineType, AggOption aggOption,
|
||||||
final List<DimSchemaResp> dimensionResps,
|
Map<String, String> visitedMetric, Set<String> measures, Set<String> dimensions) {
|
||||||
final String expression,
|
|
||||||
final MetricDefineType metricDefineType,
|
|
||||||
AggOption aggOption,
|
|
||||||
Map<String, String> visitedMetric,
|
|
||||||
Set<String> measures,
|
|
||||||
Set<String> dimensions) {
|
|
||||||
Set<String> fields = SqlSelectHelper.getColumnFromExpr(expression);
|
Set<String> fields = SqlSelectHelper.getColumnFromExpr(expression);
|
||||||
if (!CollectionUtils.isEmpty(fields)) {
|
if (!CollectionUtils.isEmpty(fields)) {
|
||||||
Map<String, String> replace = new HashMap<>();
|
Map<String, String> replace = new HashMap<>();
|
||||||
@@ -298,19 +292,11 @@ public class SqlGenerateUtils {
|
|||||||
replace.put(field, visitedMetric.get(field));
|
replace.put(field, visitedMetric.get(field));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
replace.put(
|
replace.put(field,
|
||||||
field,
|
generateDerivedMetric(metricResps, allFields, allMeasures,
|
||||||
generateDerivedMetric(
|
dimensionResps, getExpr(metricItem.get()),
|
||||||
metricResps,
|
metricItem.get().getMetricDefineType(), aggOption,
|
||||||
allFields,
|
visitedMetric, measures, dimensions));
|
||||||
allMeasures,
|
|
||||||
dimensionResps,
|
|
||||||
getExpr(metricItem.get()),
|
|
||||||
metricItem.get().getMetricDefineType(),
|
|
||||||
aggOption,
|
|
||||||
visitedMetric,
|
|
||||||
measures,
|
|
||||||
dimensions));
|
|
||||||
visitedMetric.put(field, replace.get(field));
|
visitedMetric.put(field, replace.get(field));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
"question": "超音数过去90天美术部、技术研发部的访问时长",
|
"question": "超音数过去90天美术部、技术研发部的访问时长",
|
||||||
"sideInfo": "CurrentDate=[2023-04-21]",
|
"sideInfo": "CurrentDate=[2023-04-21]",
|
||||||
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
|
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
|
||||||
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'"
|
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-21' AND 数据日期 <= '2023-04-21'"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",
|
"question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",
|
||||||
|
|||||||
@@ -74,21 +74,19 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
|
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId);
|
||||||
durations.add(System.currentTimeMillis() - start);
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
|
||||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
|
||||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
|
||||||
assert result.getQueryResults().size() == 3;
|
assert result.getQueryResults().size() == 3;
|
||||||
|
assert result.getTextResult().contains("marketing");
|
||||||
|
assert result.getTextResult().contains("sales");
|
||||||
|
assert result.getTextResult().contains("strategy");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_filter_and_top() throws Exception {
|
public void test_filter_and_top() throws Exception {
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId);
|
QueryResult result = submitNewChat("近半个月来sales部门访问量最高的用户是谁", agentId);
|
||||||
durations.add(System.currentTimeMillis() - start);
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
|
||||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
|
||||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
|
||||||
assert result.getQueryResults().size() == 1;
|
assert result.getQueryResults().size() == 1;
|
||||||
|
assert result.getTextResult().contains("tom");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -96,31 +94,37 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
|
QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId);
|
||||||
durations.add(System.currentTimeMillis() - start);
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() >= 1;
|
|
||||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
|
||||||
assert result.getQueryResults().size() == 2;
|
assert result.getQueryResults().size() == 2;
|
||||||
|
assert result.getTextResult().contains("alice");
|
||||||
|
assert result.getTextResult().contains("tom");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_filter_compare() throws Exception {
|
public void test_filter_compare() throws Exception {
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更高", agentId);
|
QueryResult result = submitNewChat("alice和lucy过去半个月谁的总停留时长更多", agentId);
|
||||||
durations.add(System.currentTimeMillis() - start);
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
|
||||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
|
||||||
assert result.getQueryColumns().get(1).getName().contains("停留时长");
|
|
||||||
assert result.getQueryResults().size() >= 1;
|
assert result.getQueryResults().size() >= 1;
|
||||||
|
assert result.getTextResult().contains("alice");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_term() throws Exception {
|
public void test_term() throws Exception {
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
QueryResult result = submitNewChat("过去半个月每个核心用户的总停留时长", agentId);
|
||||||
durations.add(System.currentTimeMillis() - start);
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() >= 1;
|
assert result.getQueryResults().size() == 2;
|
||||||
assert result.getQueryColumns().stream().filter(c -> c.getName().contains("停留时长"))
|
assert result.getTextResult().contains("tom");
|
||||||
.collect(Collectors.toList()).size() == 1;
|
assert result.getTextResult().contains("lucy");
|
||||||
assert result.getQueryResults().size() >= 1;
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test_second_calculation() throws Exception {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
|
QueryResult result = submitNewChat("近1个月总访问次数超过100次的部门有几个", agentId);
|
||||||
|
durations.add(System.currentTimeMillis() - start);
|
||||||
|
assert result.getQueryColumns().size() == 1;
|
||||||
|
assert result.getTextResult().contains("3");
|
||||||
}
|
}
|
||||||
|
|
||||||
public Agent getLLMAgent(boolean enableMultiturn) {
|
public Agent getLLMAgent(boolean enableMultiturn) {
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ public class LLMConfigUtils {
|
|||||||
default:
|
default:
|
||||||
baseUrl = "https://api.openai.com/v1";
|
baseUrl = "https://api.openai.com/v1";
|
||||||
apiKey = "REPLACE_WITH_YOUR_KEY";
|
apiKey = "REPLACE_WITH_YOUR_KEY";
|
||||||
modelName = "gpt-3.5-turbo";
|
modelName = "gpt-4o";
|
||||||
temperature = 0.0;
|
temperature = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
"question": "超音数过去90天美术部、技术研发部的访问时长",
|
"question": "超音数过去90天美术部、技术研发部的访问时长",
|
||||||
"sideInfo": "CurrentDate=[2023-04-21]",
|
"sideInfo": "CurrentDate=[2023-04-21]",
|
||||||
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
|
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
|
||||||
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'"
|
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-21' AND 数据日期 <= '2023-04-21'"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",
|
"question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",
|
||||||
|
|||||||
Reference in New Issue
Block a user