diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 3a17e973a..befec2502 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -34,9 +34,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + "\n4.DO NOT calculate date range using functions." - + "\n5.DO NOT calculate date range using DATE_SUB." - + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." - + "\n7.ALWAYS USE `with` statement to handle secondary calculation scenario.\"" + + "\n5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + + "\n6.ALWAYS use `with` statement if nested aggregation is needed." + "\n#Exemplars:\n{{exemplar}}" + "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}"; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java index 3bf97a5c3..60f4097d8 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java @@ -274,17 +274,11 @@ public class SqlGenerateUtils { return modelBizName + UNDERLINE + executorConfig.getInternalMetricNameSuffix(); } - public String generateDerivedMetric( - final List metricResps, - final Set allFields, - final Map allMeasures, - final List dimensionResps, - final String expression, - final MetricDefineType metricDefineType, - AggOption aggOption, - Map visitedMetric, - Set measures, - Set dimensions) { + public String generateDerivedMetric(final List metricResps, + final Set allFields, final Map allMeasures, + final List dimensionResps, final String expression, + final MetricDefineType metricDefineType, AggOption aggOption, + Map visitedMetric, Set measures, Set dimensions) { Set fields = SqlSelectHelper.getColumnFromExpr(expression); if (!CollectionUtils.isEmpty(fields)) { Map replace = new HashMap<>(); @@ -298,19 +292,11 @@ public class SqlGenerateUtils { replace.put(field, visitedMetric.get(field)); break; } - replace.put( - field, - generateDerivedMetric( - metricResps, - allFields, - allMeasures, - dimensionResps, - getExpr(metricItem.get()), - metricItem.get().getMetricDefineType(), - aggOption, - visitedMetric, - measures, - dimensions)); + replace.put(field, + generateDerivedMetric(metricResps, allFields, allMeasures, + dimensionResps, getExpr(metricItem.get()), + metricItem.get().getMetricDefineType(), aggOption, + visitedMetric, measures, dimensions)); visitedMetric.put(field, replace.get(field)); } break; diff --git a/launchers/standalone/src/main/resources/s2-exemplar.json b/launchers/standalone/src/main/resources/s2-exemplar.json index 0210b7bf4..9cbf09c63 100644 --- a/launchers/standalone/src/main/resources/s2-exemplar.json +++ b/launchers/standalone/src/main/resources/s2-exemplar.json @@ -15,7 +15,7 @@ "question": "超音数过去90天美术部、技术研发部的访问时长", "sideInfo": "CurrentDate=[2023-04-21]", "dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]", - "sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'" + "sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-21' AND 数据日期 <= '2023-04-21'" }, { "question": "超音数访问时长小于1小时,且来自美术部的用户是哪些", diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 23f1a38d1..f1e6869f2 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -74,21 +74,19 @@ public class Text2SQLEval extends BaseTest { long start = System.currentTimeMillis(); QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agentId); durations.add(System.currentTimeMillis() - start); - assert result.getQueryColumns().size() == 2; - assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门"); - assert result.getQueryColumns().get(1).getName().contains("访问次数"); assert result.getQueryResults().size() == 3; + assert result.getTextResult().contains("marketing"); + assert result.getTextResult().contains("sales"); + assert result.getTextResult().contains("strategy"); } @Test public void test_filter_and_top() throws Exception { long start = System.currentTimeMillis(); - QueryResult result = submitNewChat("近半个月sales部门访问量最高的用户是谁", agentId); + QueryResult result = submitNewChat("近半个月来sales部门访问量最高的用户是谁", agentId); durations.add(System.currentTimeMillis() - start); - assert result.getQueryColumns().size() == 2; - assert result.getQueryColumns().get(0).getName().contains("用户"); - assert result.getQueryColumns().get(1).getName().contains("访问次数"); assert result.getQueryResults().size() == 1; + assert result.getTextResult().contains("tom"); } @Test @@ -96,31 +94,37 @@ public class Text2SQLEval extends BaseTest { long start = System.currentTimeMillis(); QueryResult result = submitNewChat("近一个月sales部门总访问次数超过10次的用户有哪些", agentId); durations.add(System.currentTimeMillis() - start); - assert result.getQueryColumns().size() >= 1; - assert result.getQueryColumns().get(0).getName().contains("用户"); assert result.getQueryResults().size() == 2; + assert result.getTextResult().contains("alice"); + assert result.getTextResult().contains("tom"); } @Test public void test_filter_compare() throws Exception { long start = System.currentTimeMillis(); - QueryResult result = submitNewChat("alice和lucy过去半个月哪一位的总停留时长更高", agentId); + QueryResult result = submitNewChat("alice和lucy过去半个月谁的总停留时长更多", agentId); durations.add(System.currentTimeMillis() - start); - assert result.getQueryColumns().size() == 2; - assert result.getQueryColumns().get(0).getName().contains("用户"); - assert result.getQueryColumns().get(1).getName().contains("停留时长"); assert result.getQueryResults().size() >= 1; + assert result.getTextResult().contains("alice"); } @Test public void test_term() throws Exception { long start = System.currentTimeMillis(); - QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId); + QueryResult result = submitNewChat("过去半个月每个核心用户的总停留时长", agentId); durations.add(System.currentTimeMillis() - start); - assert result.getQueryColumns().size() >= 1; - assert result.getQueryColumns().stream().filter(c -> c.getName().contains("停留时长")) - .collect(Collectors.toList()).size() == 1; - assert result.getQueryResults().size() >= 1; + assert result.getQueryResults().size() == 2; + assert result.getTextResult().contains("tom"); + assert result.getTextResult().contains("lucy"); + } + + @Test + public void test_second_calculation() throws Exception { + long start = System.currentTimeMillis(); + QueryResult result = submitNewChat("近1个月总访问次数超过100次的部门有几个", agentId); + durations.add(System.currentTimeMillis() - start); + assert result.getQueryColumns().size() == 1; + assert result.getTextResult().contains("3"); } public Agent getLLMAgent(boolean enableMultiturn) { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java index 68a7a8126..63532d50b 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java @@ -59,7 +59,7 @@ public class LLMConfigUtils { default: baseUrl = "https://api.openai.com/v1"; apiKey = "REPLACE_WITH_YOUR_KEY"; - modelName = "gpt-3.5-turbo"; + modelName = "gpt-4o"; temperature = 0.0; } diff --git a/launchers/standalone/src/test/resources/s2-exemplar.json b/launchers/standalone/src/test/resources/s2-exemplar.json index 0210b7bf4..9cbf09c63 100644 --- a/launchers/standalone/src/test/resources/s2-exemplar.json +++ b/launchers/standalone/src/test/resources/s2-exemplar.json @@ -15,7 +15,7 @@ "question": "超音数过去90天美术部、技术研发部的访问时长", "sideInfo": "CurrentDate=[2023-04-21]", "dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]", - "sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'" + "sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-21' AND 数据日期 <= '2023-04-21'" }, { "question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",