(improvement)[headless] Optimize table name replacement in complex SQL (#1763)

This commit is contained in:
lexluo09
2024-10-09 23:14:18 +08:00
committed by GitHub
parent cadb743eda
commit 1215efbdce
4 changed files with 85 additions and 68 deletions

View File

@@ -64,6 +64,9 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
} }
Column column = (Column) inExpression.getLeftExpression(); Column column = (Column) inExpression.getLeftExpression();
Map<String, String> valueMap = filedNameToValueMap.get(column.getColumnName()); Map<String, String> valueMap = filedNameToValueMap.get(column.getColumnName());
if (!(inExpression.getRightExpression() instanceof ExpressionList)) {
return;
}
ExpressionList rightItemsList = (ExpressionList) inExpression.getRightExpression(); ExpressionList rightItemsList = (ExpressionList) inExpression.getRightExpression();
List<Expression> expressions = rightItemsList.getExpressions(); List<Expression> expressions = rightItemsList.getExpressions();
List<String> values = new ArrayList<>(); List<String> values = new ArrayList<>();

View File

@@ -36,7 +36,9 @@ import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -183,7 +185,6 @@ public class SqlReplaceHelper {
} }
private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameMap, private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameMap,
boolean exactReplace, PlainSelect plainSelect) { boolean exactReplace, PlainSelect plainSelect) {
// 1. replace where fields // 1. replace where fields
@@ -385,95 +386,89 @@ public class SqlReplaceHelper {
if (StringUtils.isEmpty(tableName)) { if (StringUtils.isEmpty(tableName)) {
return sql; return sql;
} }
List<String> withNameList = SqlSelectHelper.getWithName(sql);
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement); List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
if (!CollectionUtils.isEmpty(plainSelectList)) { if (!CollectionUtils.isEmpty(plainSelectList)) {
List<String> withNameList = SqlSelectHelper.getWithName(sql); plainSelectList.forEach(
plainSelectList.stream().forEach(plainSelect -> { plainSelect -> processPlainSelect(plainSelect, tableName, withNameList));
if (plainSelect.getFromItem() instanceof Table) {
Table table = (Table) plainSelect.getFromItem();
if (!withNameList.contains(table.getName())) {
replaceSingleTable(plainSelect, tableName);
}
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect =
(ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
Table table = (Table) subPlainSelect.getFromItem();
if (!withNameList.contains(table.getName())) {
replaceSingleTable(subPlainSelect, tableName);
}
}
});
return selectStatement.toString();
} }
if (selectStatement instanceof PlainSelect) { if (selectStatement instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectStatement; processPlainSelect((PlainSelect) selectStatement, tableName, withNameList);
replaceSingleTable(plainSelect, tableName);
replaceSubTable(plainSelect, tableName);
} else if (selectStatement instanceof SetOperationList) { } else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement; SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList.getSelects().forEach(subSelectBody -> { setOperationList.getSelects()
PlainSelect subPlainSelect = (PlainSelect) subSelectBody; .forEach(subSelectBody -> processPlainSelect((PlainSelect) subSelectBody,
replaceSingleTable(subPlainSelect, tableName); tableName, withNameList));
replaceSubTable(subPlainSelect, tableName);
});
} }
} }
return selectStatement.toString(); return selectStatement.toString();
} }
public static void replaceSubTable(PlainSelect plainSelect, String tableName) { private static void processPlainSelect(PlainSelect plainSelect, String tableName,
if (plainSelect.getFromItem() instanceof ParenthesedSelect) { List<String> withNameList) {
if (plainSelect.getFromItem() instanceof Table) {
replaceSingleTable(plainSelect, tableName, withNameList);
} else if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem(); ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect(); PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
replaceSingleTable(subPlainSelect, tableName); replaceSingleTable(subPlainSelect, tableName, withNameList);
} }
List<Join> joinList = plainSelect.getJoins(); replaceSubTable(plainSelect, tableName, withNameList);
if (CollectionUtils.isEmpty(joinList)) { }
return;
} public static void replaceSingleTable(PlainSelect plainSelect, String tableName,
for (Join join : joinList) { List<String> withNameList) {
if (join.getFromItem() instanceof ParenthesedSelect) { List<PlainSelect> plainSelects =
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) join.getFromItem(); SqlSelectHelper.getPlainSelects(Collections.singletonList(plainSelect));
replaceSingleTable(parenthesedSelect.getPlainSelect(), tableName); plainSelects.forEach(painSelect -> {
painSelect.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
plainSelect.getFromItem().accept(
new TableNameReplaceVisitor(tableName, new HashSet<>(withNameList)));
}
});
replaceJoins(painSelect, tableName, withNameList);
});
}
private static void replaceJoins(PlainSelect plainSelect, String tableName,
List<String> withNameList) {
List<Join> joins = plainSelect.getJoins();
if (!CollectionUtils.isEmpty(joins)) {
for (Join join : joins) {
if (join.getRightItem() instanceof ParenthesedFromItem) {
List<PlainSelect> subPlainSelects = SqlSelectHelper.getPlainSelects(
Collections.singletonList((PlainSelect) join.getRightItem()));
subPlainSelects.forEach(subPlainSelect -> subPlainSelect.getFromItem().accept(
new TableNameReplaceVisitor(tableName, new HashSet<>(withNameList))));
} else if (join.getRightItem() instanceof Table) {
Table table = (Table) join.getRightItem();
table.setName(tableName);
}
} }
} }
} }
public static void replaceSingleTable(PlainSelect plainSelect, String tableName) { public static void replaceSubTable(PlainSelect plainSelect, String tableName,
// replace table name List<String> withNameList) {
List<PlainSelect> plainSelects = new ArrayList<>(); if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
plainSelects.add(plainSelect); ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
List<PlainSelect> painSelects = SqlSelectHelper.getPlainSelects(plainSelects); replaceSingleTable(parenthesedSelect.getPlainSelect(), tableName, withNameList);
for (PlainSelect painSelect : painSelects) { }
painSelect.accept(new SelectVisitorAdapter() {
@Override List<Join> joinList = plainSelect.getJoins();
public void visit(PlainSelect plainSelect) { if (!CollectionUtils.isEmpty(joinList)) {
plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName)); joinList.forEach(join -> {
if (join.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) join.getFromItem();
replaceSingleTable(parenthesedSelect.getPlainSelect(), tableName, withNameList);
} }
}); });
List<Join> joins = painSelect.getJoins();
if (!CollectionUtils.isEmpty(joins)) {
for (Join join : joins) {
if (join.getRightItem() instanceof ParenthesedFromItem) {
List<PlainSelect> plainSelectList = new ArrayList<>();
plainSelectList.add((PlainSelect) join.getRightItem());
List<PlainSelect> subPlainSelects =
SqlSelectHelper.getPlainSelects(plainSelectList);
for (PlainSelect subPlainSelect : subPlainSelects) {
subPlainSelect.getFromItem()
.accept(new TableNameReplaceVisitor(tableName));
}
} else if (join.getRightItem() instanceof Table) {
Table table = (Table) join.getRightItem();
table.setName(tableName);
}
}
}
} }
} }

View File

@@ -3,16 +3,23 @@ package com.tencent.supersonic.common.jsqlparser;
import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.FromItemVisitorAdapter; import net.sf.jsqlparser.statement.select.FromItemVisitorAdapter;
import java.util.Set;
public class TableNameReplaceVisitor extends FromItemVisitorAdapter { public class TableNameReplaceVisitor extends FromItemVisitorAdapter {
private Set<String> notReplaceTables;
private String tableName; private String tableName;
public TableNameReplaceVisitor(String tableName) { public TableNameReplaceVisitor(String tableName, Set<String> notReplaceTables) {
this.tableName = tableName; this.tableName = tableName;
this.notReplaceTables = notReplaceTables;
} }
@Override @Override
public void visit(Table table) { public void visit(Table table) {
if (notReplaceTables.contains(table.getName())) {
return;
}
table.setName(tableName); table.setName(tableName);
} }
} }

View File

@@ -402,6 +402,18 @@ class SqlReplaceHelperTest {
Assert.assertEquals("SELECT 歌曲名称, sum(评分) FROM cspider WHERE (1 < 2) AND 数据日期 = " Assert.assertEquals("SELECT 歌曲名称, sum(评分) FROM cspider WHERE (1 < 2) AND 数据日期 = "
+ "'2023-10-15' GROUP BY 歌曲名称 HAVING sum(评分) < (SELECT min(评分) " + "'2023-10-15' GROUP BY 歌曲名称 HAVING sum(评分) < (SELECT min(评分) "
+ "FROM cspider WHERE 语种 = '英文')", replaceSql); + "FROM cspider WHERE 语种 = '英文')", replaceSql);
sql = "WITH _部门访问次数_ AS ( SELECT 部门, SUM(访问次数) AS _总访问次数_ FROM 超音数数据集 WHERE 数据日期 >= '2024-07-11'"
+ " AND 数据日期 <= '2024-10-09' GROUP BY 部门 HAVING SUM(访问次数) > 100 ) SELECT 用户, SUM(访问次数) "
+ "AS _访问次数汇总_ FROM 超音数数据集 WHERE 部门 IN ( SELECT 部门 FROM _部门访问次数_ ) AND 数据日期 >= '2024-07-11' "
+ "AND 数据日期 <= '2024-10-09' GROUP BY 用户";
replaceSql = SqlReplaceHelper.replaceTable(sql, "t_1");
Assert.assertEquals("WITH _部门访问次数_ AS (SELECT 部门, SUM(访问次数) AS _总访问次数_ FROM t_1 "
+ "WHERE 数据日期 >= '2024-07-11' AND 数据日期 <= '2024-10-09' GROUP BY 部门 HAVING SUM(访问次数) > 100) "
+ "SELECT 用户, SUM(访问次数) AS _访问次数汇总_ FROM t_1 WHERE 部门 IN (SELECT 部门 FROM _部门访问次数_) "
+ "AND 数据日期 >= '2024-07-11' AND 数据日期 <= '2024-10-09' GROUP BY 用户", replaceSql);
} }
@Test @Test