[improvement][headless] Support merging SQL queries with WITH using the LIMIT clause (#1795)

This commit is contained in:
lexluo09
2024-10-13 21:37:34 +08:00
committed by GitHub
parent 6d993b4785
commit 28b7847389
2 changed files with 135 additions and 20 deletions

View File

@@ -6,6 +6,8 @@ import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOrderBy;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlWith;
import org.apache.calcite.sql.SqlWithItem;
import org.apache.calcite.sql.SqlWriterConfig;
@@ -42,30 +44,46 @@ public class SqlMergeWithUtils {
// Create a new WITH item for parentWithName without quotes
SqlWithItem withItem = new SqlWithItem(SqlParserPos.ZERO,
new SqlIdentifier(parentWithName, SqlParserPos.ZERO), // false
// to
// avoid
// quotes
null, sqlNode2, SqlLiteral.createBoolean(false, SqlParserPos.ZERO));
new SqlIdentifier(parentWithName, SqlParserPos.ZERO), null, sqlNode2,
SqlLiteral.createBoolean(false, SqlParserPos.ZERO));
// Add the new WITH item to the list
withItemList.add(withItem);
}
// Check if the main SQL node contains a LIMIT clause
SqlNode limitNode = null;
if (sqlNode1 instanceof SqlOrderBy) {
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1;
limitNode = sqlOrderBy.fetch;
sqlNode1 = sqlOrderBy.query;
} else if (sqlNode1 instanceof SqlSelect) {
SqlSelect sqlSelect = (SqlSelect) sqlNode1;
limitNode = sqlSelect.getFetch();
sqlSelect.setFetch(null);
sqlNode1 = sqlSelect;
}
// Extract existing WITH items from sqlNode1 if it is a SqlWith
if (sqlNode1 instanceof SqlWith) {
SqlWith sqlWith = (SqlWith) sqlNode1;
withItemList.addAll(sqlWith.withList.getList());
sqlNode1 = sqlWith.body;
}
// Create a new SqlWith node
SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
// If there was a LIMIT clause, wrap the finalSqlNode in a SqlOrderBy with the LIMIT
SqlNode resultNode = finalSqlNode;
if (limitNode != null) {
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, SqlNodeList.EMPTY, null,
limitNode);
}
// Custom SqlPrettyWriter configuration to avoid quoting identifiers
SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType);
// Pretty print the final SQL
SqlPrettyWriter writer = new SqlPrettyWriter(config);
return writer.format(finalSqlNode);
return writer.format(resultNode);
}
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.calcite;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import java.util.Collections;
@@ -11,8 +12,7 @@ import java.util.Collections;
class SqlWithMergerTest {
@Test
void testWithMerger() throws SqlParseException {
void test1() throws SqlParseException {
String sql1 = "WITH DepartmentVisits AS (\n" + " SELECT department, SUM(pv) AS 总访问次数\n"
+ " FROM t_1\n"
+ " WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29'\n"
@@ -28,20 +28,117 @@ class SqlWithMergerTest {
String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
Collections.singletonList(sql2), Collections.singletonList("t_1"));
System.out.println(mergeSql);
Assert.assertEquals(format(mergeSql),
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` "
+ "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN "
+ "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) "
+ "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`), DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 "
+ "FROM t_1 WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29' GROUP BY department) "
+ "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100");
}
sql1 = "WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' "
+ "AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000";
@Test
void test2() throws SqlParseException {
sql2 = "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
+ "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
+ "`s2_user_department`) AS `t2`\n"
+ "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
+ "FROM\n" + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";
String sql1 =
"WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' "
+ "AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000";
mergeSql = SqlMergeWithUtils.mergeWith(EngineType.H2, sql1, Collections.singletonList(sql2),
Collections.singletonList("t_1"));
String sql2 =
"SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
+ "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
+ "`s2_user_department`) AS `t2`\n"
+ "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
+ "FROM\n"
+ "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";
System.out.println(mergeSql);
String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
Collections.singletonList(sql2), Collections.singletonList("t_1"));
Assert.assertEquals(format(mergeSql),
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` "
+ "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN "
+ "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON "
+ "`t2`.`user_name` = `t3`.`user_name`), DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 "
+ "WHERE sys_imp_date >= '2024-08-28' AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) "
+ "FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000");
}
/**
 * Merges a plain SELECT that ends with {@code LIMIT 1000} with one parent query registered
 * under the name {@code t_1}, and checks that the result starts with the new
 * {@code WITH t_1 AS (...)} item and still ends with {@code LIMIT 1000}.
 *
 * @throws SqlParseException declared by {@code SqlMergeWithUtils.mergeWith}
 */
@Test
void test3() throws SqlParseException {
    String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000";

    String sql2 =
            "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
                    + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
                    + "`s2_user_department`) AS `t2`\n"
                    + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
                    + "FROM\n"
                    + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";

    String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
            Collections.singletonList(sql2), Collections.singletonList("t_1"));

    String expected =
            "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` "
                    + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN "
                    + "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) "
                    + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 "
                    + "LIMIT 1000";

    // assertEquals takes (expected, actual); passing them in that order keeps the
    // failure message's "expected" and "but was" labels attached to the right values.
    Assert.assertEquals(expected, format(mergeSql));
}
/**
 * Merges a plain SELECT without a LIMIT clause with one parent query registered under the
 * name {@code t_1}, and checks that the result is the new {@code WITH t_1 AS (...)} item
 * followed by the unchanged main query, with no LIMIT appended.
 *
 * @throws SqlParseException declared by {@code SqlMergeWithUtils.mergeWith}
 */
@Test
void test4() throws SqlParseException {
    String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100";

    String sql2 =
            "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
                    + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
                    + "`s2_user_department`) AS `t2`\n"
                    + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
                    + "FROM\n"
                    + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";

    String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
            Collections.singletonList(sql2), Collections.singletonList("t_1"));

    String expected =
            "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` "
                    + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` "
                    + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) "
                    + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100";

    // assertEquals takes (expected, actual); passing them in that order keeps the
    // failure message's "expected" and "but was" labels attached to the right values.
    Assert.assertEquals(expected, format(mergeSql));
}
/**
 * Merges a main query whose FROM clause uses a bare {@code join} with one parent query
 * registered under the name {@code t_1}. The expected output has the new
 * {@code WITH t_1 AS (...)} item in front and the join rendered as {@code INNER JOIN}.
 *
 * @throws SqlParseException declared by {@code SqlMergeWithUtils.mergeWith}
 */
@Test
void test5() throws SqlParseException {
    String sql1 = "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100";

    String sql2 =
            "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
                    + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
                    + "`s2_user_department`) AS `t2`\n"
                    + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
                    + "FROM\n"
                    + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";

    String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
            Collections.singletonList(sql2), Collections.singletonList("t_1"));

    String expected =
            "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` "
                    + "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN "
                    + "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) "
                    + "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM Department INNER JOIN Visits "
                    + "WHERE 总访问次数 > 100";

    // assertEquals takes (expected, actual); passing them in that order keeps the
    // failure message's "expected" and "but was" labels attached to the right values.
    Assert.assertEquals(expected, format(mergeSql));
}
/**
 * Normalizes SQL text for comparison: converts CRLF line endings to LF, collapses each run
 * of whitespace into a single space, and trims the result.
 *
 * @param mergeSql the SQL text to normalize
 * @return the single-line, single-spaced form of {@code mergeSql}
 */
private static String format(String mergeSql) {
    final String unixLineEndings = mergeSql.replace("\r\n", "\n");
    // Every whitespace run (spaces, tabs, newlines) becomes exactly one space.
    final String singleSpaced = unixLineEndings.replaceAll("\\s+", " ");
    return singleSpaced.trim();
}
}