[improvement][common] Field replacement is performed using a recursive approach, and it supports field replacement with complex expressions (#1859)

This commit is contained in:
lexluo09
2024-10-30 13:17:28 +08:00
committed by GitHub
parent 9644fc4207
commit 5c70607851
5 changed files with 161 additions and 40 deletions

View File

@@ -2,9 +2,13 @@ package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.AnalyticExpression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.WindowDefinition;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.OrderByElement;
import org.springframework.util.CollectionUtils;
import java.util.Map; import java.util.Map;
@@ -34,4 +38,16 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
exactReplace.set(originalExactReplace); exactReplace.set(originalExactReplace);
} }
} }
/**
 * Visits an analytic (window) expression.
 *
 * <p>The default traversal of the parent adapter runs first; afterwards every
 * ORDER BY expression of the window definition is visited with this same
 * visitor. NOTE(review): presumably the parent traversal does not reach these
 * expressions in the JSqlParser version in use - confirm before removing.
 *
 * @param expr the analytic expression being visited
 */
@Override
public void visit(AnalyticExpression expr) {
    super.visit(expr);
    WindowDefinition window = expr.getWindowDefinition();
    // Nothing extra to walk when there is no window or it has no ORDER BY.
    if (window == null || CollectionUtils.isEmpty(window.getOrderByElements())) {
        return;
    }
    window.getOrderByElements()
            .forEach(orderBy -> orderBy.getExpression().accept(this));
}
} }

View File

@@ -154,37 +154,20 @@ public class SqlReplaceHelper {
public static String replaceFields(String sql, Map<String, String> fieldNameMap, public static String replaceFields(String sql, Map<String, String> fieldNameMap,
boolean exactReplace) { boolean exactReplace) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement); Set<Select> plainSelectList = SqlSelectHelper.getAllSelect(selectStatement);
if (selectStatement instanceof PlainSelect) { for (Select plainSelect : plainSelectList) {
PlainSelect plainSelect = (PlainSelect) selectStatement; if (plainSelect instanceof PlainSelect) {
plainSelectList.add(plainSelect); replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace,
getFromSelect(plainSelect.getFromItem(), plainSelectList); (PlainSelect) plainSelect);
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
});
} }
List<OrderByElement> orderByElements = setOperationList.getOrderByElements(); if (plainSelect instanceof SetOperationList) {
if (!CollectionUtils.isEmpty(orderByElements)) { replaceFieldsInSetOperationList(fieldNameMap, exactReplace,
for (OrderByElement orderByElement : orderByElements) { (SetOperationList) plainSelect);
orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace));
}
} }
} else {
return sql;
}
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList);
for (PlainSelect plainSelect : plainSelects) {
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, plainSelect);
} }
return selectStatement.toString(); return selectStatement.toString();
} }
private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameMap, private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameMap,
boolean exactReplace, PlainSelect plainSelect) { boolean exactReplace, PlainSelect plainSelect) {
// 1. replace where fields // 1. replace where fields
@@ -236,22 +219,24 @@ public class SqlReplaceHelper {
List<Join> joins = plainSelect.getJoins(); List<Join> joins = plainSelect.getJoins();
if (!CollectionUtils.isEmpty(joins)) { if (!CollectionUtils.isEmpty(joins)) {
for (Join join : joins) { for (Join join : joins) {
if (!CollectionUtils.isEmpty(join.getOnExpressions())) { if (CollectionUtils.isEmpty(join.getOnExpressions())) {
join.getOnExpressions().stream().forEach(onExpression -> {
onExpression.accept(visitor);
});
}
if (!(join.getRightItem() instanceof ParenthesedSelect)) {
continue; continue;
} }
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) join.getRightItem(); join.getOnExpressions().stream().forEach(onExpression -> {
List<PlainSelect> plainSelectList = new ArrayList<>(); onExpression.accept(visitor);
plainSelectList.add(parenthesedSelect.getPlainSelect()); });
List<PlainSelect> subPlainSelects = }
SqlSelectHelper.getPlainSelects(plainSelectList); }
for (PlainSelect subPlainSelect : subPlainSelects) { }
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect);
}
private static void replaceFieldsInSetOperationList(Map<String, String> fieldNameMap,
boolean exactReplace, SetOperationList operationList) {
List<OrderByElement> orderByElements = operationList.getOrderByElements();
if (!CollectionUtils.isEmpty(orderByElements)) {
for (OrderByElement orderByElement : orderByElements) {
orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace));
} }
} }
} }

View File

@@ -26,6 +26,7 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement; import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.Join; import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.LateralView; import net.sf.jsqlparser.statement.select.LateralView;
@@ -50,7 +51,9 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** Sql Parser Select Helper */ /**
* Sql Parser Select Helper
*/
@Slf4j @Slf4j
public class SqlSelectHelper { public class SqlSelectHelper {
@@ -808,4 +811,103 @@ public class SqlSelectHelper {
} }
} }
} }
/**
 * Collects the selects reachable from the given statement.
 *
 * @param selectStatement the root select; {@code null} yields an empty set
 * @return a new mutable set holding every select gathered by
 *         {@code collectSelects}
 */
public static Set<Select> getAllSelect(Select selectStatement) {
    final Set<Select> collected = new HashSet<>();
    collectSelects(selectStatement, collected);
    return collected;
}
/**
 * Recursively gathers {@code select} and the selects nested inside it.
 *
 * <ul>
 * <li>Plain select: added, then its FROM item, WITH items, joins and nested
 * expression sub-selects are traversed.</li>
 * <li>Set operation: added, then its WITH items and each operand are
 * traversed.</li>
 * <li>WITH item / parenthesised select: not added themselves; the wrapped
 * select is traversed.</li>
 * </ul>
 *
 * @param select the select to traverse; {@code null} is ignored
 * @param selects accumulator receiving every plain select and set operation
 */
private static void collectSelects(Select select, Set<Select> selects) {
    if (select == null) {
        return;
    }
    if (select instanceof PlainSelect) {
        PlainSelect plainSelect = (PlainSelect) select;
        selects.add(plainSelect);
        collectFromItemPlainSelects(plainSelect.getFromItem(), selects);
        collectWithItemPlainSelects(plainSelect.getWithItemsList(), selects);
        collectJoinsPlainSelects(plainSelect.getJoins(), selects);
        collectNestedPlainSelects(plainSelect, selects);
    } else if (select instanceof SetOperationList) {
        SetOperationList setOperationList = (SetOperationList) select;
        selects.add(setOperationList);
        // A "WITH ... SELECT ... UNION SELECT ..." statement carries its WITH
        // items on the set operation itself, so they must be traversed here too.
        collectWithItemPlainSelects(setOperationList.getWithItemsList(), selects);
        if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
            for (Select operand : setOperationList.getSelects()) {
                collectSelects(operand, selects);
            }
        }
    } else if (select instanceof WithItem) {
        // Checked before ParenthesedSelect on purpose; keep this order.
        collectSelects(((WithItem) select).getSelect(), selects);
    } else if (select instanceof ParenthesedSelect) {
        // getSelect() rather than getPlainSelect(): the wrapped select may be a
        // set operation or another parenthesised select, not only a plain one.
        collectSelects(((ParenthesedSelect) select).getSelect(), selects);
    }
}
/**
 * Gathers the selects used as the right-hand side of joins.
 *
 * <p>Only joins whose right item is a parenthesised select are considered.
 * The wrapped select is traversed recursively through {@code collectSelects},
 * so sub-selects nested inside a joined sub-query (and joined set operations)
 * are collected as well, consistent with how FROM items are handled.
 *
 * @param joins the joins of a plain select; may be {@code null} or empty
 * @param selects accumulator receiving the collected selects
 */
private static void collectJoinsPlainSelects(List<Join> joins, Set<Select> selects) {
    if (CollectionUtils.isEmpty(joins)) {
        return;
    }
    for (Join join : joins) {
        FromItem rightItem = join.getRightItem();
        if (rightItem instanceof ParenthesedSelect) {
            collectSelects(((ParenthesedSelect) rightItem).getSelect(), selects);
        }
    }
}
/**
 * Traverses the select wrapped by a FROM item when that item is a
 * parenthesised select; any other FROM item is ignored.
 *
 * @param fromItem the FROM item of a plain select; may be {@code null}
 * @param selects accumulator receiving the collected selects
 */
private static void collectFromItemPlainSelects(FromItem fromItem, Set<Select> selects) {
    if (!(fromItem instanceof ParenthesedSelect)) {
        return;
    }
    Select wrapped = ((ParenthesedSelect) fromItem).getSelect();
    collectSelects(wrapped, selects);
}
/**
 * Traverses the select of every WITH item in the list.
 *
 * @param withItemList the WITH items to traverse; may be {@code null} or empty
 * @param selects accumulator receiving the collected selects
 */
public static void collectWithItemPlainSelects(List<WithItem> withItemList,
        Set<Select> selects) {
    if (!CollectionUtils.isEmpty(withItemList)) {
        withItemList.forEach(item -> collectSelects(item.getSelect(), selects));
    }
}
/**
 * Gathers sub-selects that appear inside expressions of a plain select.
 *
 * <p>Only the WHERE clause, the HAVING clause and the select items are
 * scanned. Every parenthesised sub-select found there is traversed
 * recursively through {@code collectSelects}, so deeper nesting levels and
 * parenthesised set operations are collected too.
 *
 * @param plainSelect the plain select whose expressions are scanned
 * @param selects accumulator receiving the collected selects
 */
private static void collectNestedPlainSelects(PlainSelect plainSelect, Set<Select> selects) {
    ExpressionVisitorAdapter subSelectCollector = new ExpressionVisitorAdapter() {
        @Override
        public void visit(Select subSelect) {
            if (subSelect instanceof ParenthesedSelect) {
                // Unwrap here and recurse on the inner select so that its own
                // FROM / JOIN / WITH / nested sub-selects are not missed.
                collectSelects(((ParenthesedSelect) subSelect).getSelect(), selects);
            }
        }
    };

    Expression where = plainSelect.getWhere();
    if (where != null) {
        where.accept(subSelectCollector);
    }
    Expression having = plainSelect.getHaving();
    if (having != null) {
        having.accept(subSelectCollector);
    }
    List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
    if (!CollectionUtils.isEmpty(selectItems)) {
        for (SelectItem<?> selectItem : selectItems) {
            selectItem.accept(subSelectCollector);
        }
    }
}
} }

View File

@@ -263,4 +263,21 @@ class SqlReplaceFieldsTest extends SqlReplaceHelperTest {
+ "SELECT * FROM daily_visits", replaceSql); + "SELECT * FROM daily_visits", replaceSql);
} }
/**
 * A CTE whose window function orders by a mapped field: the fields inside the
 * CTE (select item, window ORDER BY, WHERE) are expected to be replaced, while
 * {@code AVG(__粉丝数__)} in the outer select is expected to stay as written.
 */
@Test
void testReplaceFields18() {
    String originalSql = "WITH\n"
            + " latest_data AS (\n"
            + " SELECT\n"
            + " 粉丝数,\n"
            + " ROW_NUMBER() OVER (\n"
            + " ORDER BY\n"
            + " 数据日期 DESC\n"
            + " ) AS __row_num__\n"
            + " FROM\n"
            + " 问答艺人数据集\n"
            + " WHERE\n"
            + " (TME歌手ID = '1')\n"
            + " AND (\n"
            + " 数据日期 >= '2024-10-22'\n"
            + " AND 数据日期 <= '2024-10-29'\n"
            + " )\n"
            + " )\n"
            + "SELECT\n"
            + " AVG(__粉丝数__)\n"
            + "FROM\n"
            + " latest_data\n"
            + "WHERE\n"
            + " __row_num__ = 1";
    String expectedSql = "WITH latest_data AS (SELECT fans_cnt, ROW_NUMBER() OVER "
            + "(ORDER BY sys_imp_date DESC) AS __row_num__ FROM 问答艺人数据集 WHERE (TME歌手ID = '1') "
            + "AND (sys_imp_date >= '2024-10-22' AND sys_imp_date <= '2024-10-29')) SELECT AVG(__粉丝数__) "
            + "FROM latest_data WHERE __row_num__ = 1";

    String actualSql = SqlReplaceHelper.replaceFields(originalSql, fieldToBizName);

    Assert.assertEquals(expectedSql, actualSql);
}
} }

View File

@@ -348,6 +348,7 @@ class SqlReplaceHelperTest {
fieldToBizName.put("歌曲发布时间", "song_publis_date"); fieldToBizName.put("歌曲发布时间", "song_publis_date");
fieldToBizName.put("歌曲发布年份", "song_publis_year"); fieldToBizName.put("歌曲发布年份", "song_publis_year");
fieldToBizName.put("访问次数", "pv"); fieldToBizName.put("访问次数", "pv");
fieldToBizName.put("粉丝数", "fans_cnt");
return fieldToBizName; return fieldToBizName;
} }