diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java index 38f745c21..34922774a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java @@ -2,19 +2,11 @@ package com.tencent.supersonic.common.calcite; import com.tencent.supersonic.common.pojo.enums.EngineType; import lombok.extern.slf4j.Slf4j; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlLiteral; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.SqlOrderBy; -import org.apache.calcite.sql.SqlSelect; -import org.apache.calcite.sql.SqlWith; -import org.apache.calcite.sql.SqlWithItem; -import org.apache.calcite.sql.SqlWriterConfig; -import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.pretty.SqlPrettyWriter; +import net.sf.jsqlparser.expression.Alias; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.select.ParenthesedSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.WithItem; import java.util.ArrayList; import java.util.List; @@ -22,85 +14,37 @@ import java.util.List; @Slf4j public class SqlMergeWithUtils { public static String mergeWith(EngineType engineType, String sql, List parentSqlList, - List parentWithNameList) throws SqlParseException { - SqlParser.Config parserConfig = Configuration.getParserConfig(engineType); + List parentWithNameList) throws Exception { - // Parse the main SQL statement - SqlParser parser = SqlParser.create(sql, parserConfig); - SqlNode sqlNode1 = parser.parseQuery(); + Select selectStatement = (Select) CCJSqlParserUtil.parse(sql); + List withItemList = new ArrayList<>(); - // List to hold all WITH items - List withItemList = new ArrayList<>(); - - // Iterate over each parentSql and parentWithName pair for (int i = 0; i < parentSqlList.size(); i++) { String parentSql = parentSqlList.get(i); String parentWithName = parentWithNameList.get(i); - // Parse the parent SQL statement - parser = SqlParser.create(parentSql, parserConfig); - SqlNode sqlNode2 = parser.parseQuery(); + Select parentSelect = (Select) CCJSqlParserUtil.parse(parentSql); + ParenthesedSelect select = new ParenthesedSelect(); + select.setSelect(parentSelect); // Create a new WITH item for parentWithName without quotes - SqlWithItem withItem = new SqlWithItem(SqlParserPos.ZERO, - new SqlIdentifier(parentWithName, SqlParserPos.ZERO), null, sqlNode2, - SqlLiteral.createBoolean(false, SqlParserPos.ZERO)); + WithItem withItem = new WithItem(); + withItem.setAlias(new Alias(parentWithName)); + withItem.setSelect(select); // Add the new WITH item to the list withItemList.add(withItem); } - // Check if the main SQL node contains an ORDER BY or LIMIT clause - SqlNode limitNode = null; - SqlNodeList orderByList = null; - if (sqlNode1 instanceof SqlOrderBy) { - SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1; - limitNode = sqlOrderBy.fetch; - orderByList = sqlOrderBy.orderList; - sqlNode1 = sqlOrderBy.query; - } else if (sqlNode1 instanceof SqlSelect) { - SqlSelect sqlSelect = (SqlSelect) sqlNode1; - limitNode = sqlSelect.getFetch(); - sqlSelect.setFetch(null); - sqlNode1 = sqlSelect; + // Extract existing WITH items from mainSelectBody if it has any + if (selectStatement.getWithItemsList() != null) { + withItemList.addAll(selectStatement.getWithItemsList()); } - // Extract existing WITH items from sqlNode1 if it is a SqlWith - if (sqlNode1 instanceof SqlWith) { - SqlWith sqlWith = (SqlWith) sqlNode1; - withItemList.addAll(sqlWith.withList.getList()); - sqlNode1 = sqlWith.body; - } + // Set the new WITH items list to the main select body + selectStatement.setWithItemsList(withItemList); - // Create a new SqlWith node - SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO, - new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1); - - // If there was an ORDER BY or LIMIT clause, wrap the finalSqlNode in a SqlOrderBy - SqlNode resultNode = finalSqlNode; - if (orderByList != null || limitNode != null) { - resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, - orderByList != null ? orderByList : SqlNodeList.EMPTY, null, limitNode); - } - - // Custom SqlPrettyWriter configuration to avoid quoting identifiers - SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType); // Pretty print the final SQL - SqlPrettyWriter writer = new SqlPrettyWriter(config); - return writer.format(resultNode); - } - - public static boolean hasWith(EngineType engineType, String sql) throws SqlParseException { - SqlParser.Config parserConfig = Configuration.getParserConfig(engineType); - SqlParser parser = SqlParser.create(sql, parserConfig); - SqlNode sqlNode = parser.parseQuery(); - SqlNode sqlSelect = sqlNode; - if (sqlNode instanceof SqlOrderBy) { - SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode; - sqlSelect = sqlOrderBy.query; - } else if (sqlNode instanceof SqlSelect) { - sqlSelect = (SqlSelect) sqlNode; - } - return sqlSelect instanceof SqlWith; + return selectStatement.toString(); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index 11666a86e..fc4281a6e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -989,6 +989,15 @@ public class SqlSelectHelper { for (SelectItem selectItem : selectItems) { selectItem.accept(visitor); } + if (plainSelect.getHaving() != null) { + plainSelect.getHaving().accept(visitor); + } + if (!CollectionUtils.isEmpty(plainSelect.getOrderByElements())) { + for (OrderByElement orderByElement : plainSelect.getOrderByElements()) { + orderByElement.getExpression().accept(visitor); + } + } + return !visitor.getFunctionNames().isEmpty(); } diff --git a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java index 5da8fc84b..e05135902 100644 --- a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.common.calcite; import com.tencent.supersonic.common.pojo.enums.EngineType; import lombok.extern.slf4j.Slf4j; -import org.apache.calcite.sql.parser.SqlParseException; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -12,7 +11,7 @@ import java.util.Collections; class SqlWithMergerTest { @Test - void test1() throws SqlParseException { + void test1() throws Exception { String sql1 = "WITH DepartmentVisits AS (\n" + " SELECT department, SUM(pv) AS 总访问次数\n" + " FROM t_1\n" + " WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29'\n" @@ -38,7 +37,7 @@ class SqlWithMergerTest { } @Test - void test2() throws SqlParseException { + void test2() throws Exception { String sql1 = "WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' " @@ -65,7 +64,7 @@ class SqlWithMergerTest { } @Test - void test3() throws SqlParseException { + void test3() throws Exception { String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000"; @@ -89,7 +88,7 @@ class SqlWithMergerTest { } @Test - void test4() throws SqlParseException { + void test4() throws Exception { String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100"; String sql2 = @@ -112,7 +111,7 @@ class SqlWithMergerTest { } @Test - void test5() throws SqlParseException { + void test5() throws Exception { String sql1 = "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100"; @@ -132,13 +131,13 @@ class SqlWithMergerTest { "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` " + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN " + "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) " - + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM Department INNER JOIN Visits " + + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM Department JOIN Visits " + "WHERE 总访问次数 > 100"); } @Test - void test6() throws SqlParseException { + void test6() throws Exception { String sql1 = "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10"; @@ -159,7 +158,36 @@ class SqlWithMergerTest { "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM " + "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`," + " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) " - + "SELECT COUNT(*) FROM Department INNER JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10"); + + "SELECT COUNT(*) FROM Department JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10"); + } + + @Test + void test7() throws Exception { + + String sql1 = + "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 AND imp_date >= CURRENT_DATE - " + + "INTERVAL '1 year' AND sys_imp_date < CURRENT_DATE ORDER" + + " BY 总访问次数 LIMIT 10"; + + String sql2 = + "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" + + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" + + "`s2_user_department`) AS `t2`\n" + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + + "FROM\n" + + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; + + String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, + Collections.singletonList(sql2), Collections.singletonList("t_1")); + + + Assert.assertEquals(format(mergeSql), + "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM " + + "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`," + + " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) " + + "SELECT COUNT(*) FROM Department JOIN Visits WHERE 总访问次数 > 100 AND imp_date >= " + + "CURRENT_DATE - INTERVAL '1 year' AND sys_imp_date < CURRENT_DATE ORDER BY 总访问次数 " + + "LIMIT 10"); } private static String format(String mergeSql) {