diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java index a0c375698..ccf5fb701 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java @@ -6,6 +6,8 @@ import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOrderBy; +import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.SqlWith; import org.apache.calcite.sql.SqlWithItem; import org.apache.calcite.sql.SqlWriterConfig; @@ -42,30 +44,46 @@ public class SqlMergeWithUtils { // Create a new WITH item for parentWithName without quotes SqlWithItem withItem = new SqlWithItem(SqlParserPos.ZERO, - new SqlIdentifier(parentWithName, SqlParserPos.ZERO), // false - // to - // avoid - // quotes - null, sqlNode2, SqlLiteral.createBoolean(false, SqlParserPos.ZERO)); + new SqlIdentifier(parentWithName, SqlParserPos.ZERO), null, sqlNode2, + SqlLiteral.createBoolean(false, SqlParserPos.ZERO)); // Add the new WITH item to the list withItemList.add(withItem); } + // Check if the main SQL node contains a LIMIT clause + SqlNode limitNode = null; + if (sqlNode1 instanceof SqlOrderBy) { + SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1; + limitNode = sqlOrderBy.fetch; + sqlNode1 = sqlOrderBy.query; + } else if (sqlNode1 instanceof SqlSelect) { + SqlSelect sqlSelect = (SqlSelect) sqlNode1; + limitNode = sqlSelect.getFetch(); + sqlSelect.setFetch(null); + sqlNode1 = sqlSelect; + } // Extract existing WITH items from sqlNode1 if it is a SqlWith if (sqlNode1 instanceof SqlWith) { SqlWith sqlWith = (SqlWith) sqlNode1; withItemList.addAll(sqlWith.withList.getList()); sqlNode1 = sqlWith.body; } - // Create a new SqlWith node SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO, new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1); + + // If there was a LIMIT clause, wrap the finalSqlNode in a SqlOrderBy with the LIMIT + SqlNode resultNode = finalSqlNode; + if (limitNode != null) { + resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, SqlNodeList.EMPTY, null, + limitNode); + } + // Custom SqlPrettyWriter configuration to avoid quoting identifiers SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType); // Pretty print the final SQL SqlPrettyWriter writer = new SqlPrettyWriter(config); - return writer.format(finalSqlNode); + return writer.format(resultNode); } } diff --git a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java index 76a16c5c1..53e8eef13 100644 --- a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.common.calcite; import com.tencent.supersonic.common.pojo.enums.EngineType; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.Assert; import org.junit.jupiter.api.Test; import java.util.Collections; @@ -11,8 +12,7 @@ import java.util.Collections; class SqlWithMergerTest { @Test - void testWithMerger() throws SqlParseException { - + void test1() throws SqlParseException { String sql1 = "WITH DepartmentVisits AS (\n" + " SELECT department, SUM(pv) AS 总访问次数\n" + " FROM t_1\n" + " WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29'\n" @@ -28,20 +28,117 @@ class SqlWithMergerTest { String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, Collections.singletonList(sql2), Collections.singletonList("t_1")); - System.out.println(mergeSql); + Assert.assertEquals(format(mergeSql), + "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` " + + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN " + + "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) " + + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`), DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 " + + "FROM t_1 WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29' GROUP BY department) " + + "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100"); + } - sql1 = "WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' " - + "AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000"; + @Test + void test2() throws SqlParseException { - sql2 = "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" - + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" - + "`s2_user_department`) AS `t2`\n" - + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" - + "FROM\n" + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; + String sql1 = + "WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' " + + "AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000"; - mergeSql = SqlMergeWithUtils.mergeWith(EngineType.H2, sql1, Collections.singletonList(sql2), - Collections.singletonList("t_1")); + String sql2 = + "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" + + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" + + "`s2_user_department`) AS `t2`\n" + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + + "FROM\n" + + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; - System.out.println(mergeSql); + String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, + Collections.singletonList(sql2), Collections.singletonList("t_1")); + + Assert.assertEquals(format(mergeSql), + "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` " + + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN " + + "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON " + + "`t2`.`user_name` = `t3`.`user_name`), DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 " + + "WHERE sys_imp_date >= '2024-08-28' AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) " + + "FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000"); + } + + @Test + void test3() throws SqlParseException { + + String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000"; + + String sql2 = + "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" + + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" + + "`s2_user_department`) AS `t2`\n" + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + + "FROM\n" + + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; + + String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, + Collections.singletonList(sql2), Collections.singletonList("t_1")); + + Assert.assertEquals(format(mergeSql), + "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` " + + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN " + + "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) " + + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 " + + "LIMIT 1000"); + } + + @Test + void test4() throws SqlParseException { + String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100"; + + String sql2 = + "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" + + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" + + "`s2_user_department`) AS `t2`\n" + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + + "FROM\n" + + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; + + String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, + Collections.singletonList(sql2), Collections.singletonList("t_1")); + + Assert.assertEquals(format(mergeSql), + "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` " + + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` " + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) " + + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100"); + + } + + @Test + void test5() throws SqlParseException { + + String sql1 = "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100"; + + String sql2 = + "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" + + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" + + "`s2_user_department`) AS `t2`\n" + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + + "FROM\n" + + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; + + String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, + Collections.singletonList(sql2), Collections.singletonList("t_1")); + + + Assert.assertEquals(format(mergeSql), + "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` " + + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN " + + "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) " + + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM Department INNER JOIN Visits " + + "WHERE 总访问次数 > 100"); + } + + private static String format(String mergeSql) { + mergeSql = mergeSql.replace("\r\n", "\n"); + // Remove extra spaces and newlines + return mergeSql.replaceAll("\\s+", " ").trim(); } }