diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java index dd442722c..7fc4babb2 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java @@ -1,19 +1,15 @@ package com.tencent.supersonic.common.jsqlparser; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.operators.relational.ExpressionList; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.GroupByVisitor; -import org.apache.commons.lang3.StringUtils; import java.util.List; import java.util.Map; -import java.util.Objects; @Slf4j public class GroupByReplaceVisitor implements GroupByVisitor { @@ -33,45 +29,15 @@ public class GroupByReplaceVisitor implements GroupByVisitor { for (int i = 0; i < groupByExpressions.size(); i++) { Expression expression = groupByExpressions.get(i); - String columnName = getColumnName(expression); - - String replaceColumn = - parseVisitorHelper.getReplaceValue(columnName, fieldNameMap, exactReplace); - if (StringUtils.isNotEmpty(replaceColumn)) { - replaceExpression(groupByExpressions, i, expression, replaceColumn); - } + replaceExpression(expression); } } - private String getColumnName(Expression expression) { - if (expression instanceof Function) { - Function function = (Function) expression; - if (Objects.nonNull(function.getParameters().getExpressions().get(0))) { - return function.getParameters().getExpressions().get(0).toString(); - } - } - return expression.toString(); - } - - private void replaceExpression(List groupByExpressions, int index, - Expression expression, String replaceColumn) { + private void replaceExpression(Expression expression) { if (expression instanceof Column) { - groupByExpressions.set(index, new Column(replaceColumn)); + parseVisitorHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace); } else if (expression instanceof Function) { - try { - Expression newExpression = CCJSqlParserUtil.parseExpression(replaceColumn); - ExpressionList newExpressionList = new ExpressionList<>(); - newExpressionList.add(newExpression); - - Function function = (Function) expression; - if (function.getParameters().size() > 1) { - function.getParameters().stream().skip(1) - .forEach(e -> newExpressionList.add((Function) e)); - } - function.setParameters(newExpressionList); - } catch (JSQLParserException e) { - log.error("Error parsing expression: {}", replaceColumn, e); - } + parseVisitorHelper.replaceFunction((Function) expression, fieldNameMap, exactReplace); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java index 58b5466a4..0af48f49f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.common.jsqlparser; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; -import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter; @@ -27,15 +26,9 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter { parseVisitorHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace); } if (expression instanceof Function) { - Function function = (Function) expression; - // List expressions = function.getParameters().getExpressions(); - ExpressionList expressions = function.getParameters(); - for (Expression column : expressions) { - if (column instanceof Column) { - parseVisitorHelper.replaceColumn((Column) column, fieldNameMap, exactReplace); - } - } + parseVisitorHelper.replaceFunction((Function) expression, fieldNameMap, exactReplace); } super.visit(orderBy); } + } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java index 37297a00d..2ab3d00eb 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java @@ -2,6 +2,9 @@ package com.tencent.supersonic.common.jsqlparser; import com.tencent.supersonic.common.util.StringUtil; import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.schema.Column; import org.apache.commons.lang3.StringUtils; @@ -13,6 +16,17 @@ import java.util.stream.Collectors; @Slf4j public class ParseVisitorHelper { + public void replaceFunction(Function expression, Map fieldNameMap, + boolean exactReplace) { + Function function = expression; + ExpressionList expressions = function.getParameters(); + for (Expression column : expressions) { + if (column instanceof Column) { + replaceColumn((Column) column, fieldNameMap, exactReplace); + } + } + } + public void replaceColumn(Column column, Map fieldNameMap, boolean exactReplace) { String columnName = StringUtil.replaceBackticks(column.getColumnName()); diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java index a1799cdbe..0175e90ed 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java @@ -11,9 +11,10 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; -/** SqlParserReplaceHelperTest */ +/** + * SqlParserReplaceHelperTest + */ class SqlReplaceHelperTest { - @Test void testReplaceAggField() { String sql = "SELECT 维度1,sum(播放量) FROM 数据库 " @@ -178,188 +179,6 @@ class SqlReplaceHelperTest { replaceSql); } - @Test - void testReplaceFields() { - - Map fieldToBizName = initParams(); - String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE (publish_date >= '2023-08-08' AND publish_date <= '2023-08-09')" - + " AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND song_publis_date = '2023-08-01'" - + " ORDER BY play_count DESC LIMIT 11", - replaceSql); - - replaceSql = - "select 品牌名称 from 互联网企业 where datediff('year', 品牌成立时间, '2023-11-04') > 17 and 注册资本 = 50000000"; - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - replaceSql = SqlRemoveHelper.removeNumberFilter(replaceSql); - Assert.assertEquals( - "SELECT 品牌名称 FROM 互联网企业 WHERE 品牌成立时间 < '2006-11-04' AND 注册资本 = 50000000", - replaceSql); - - replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " - + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " - + "group by MONTH(数据日期) order by sum(访问次数) desc limit 1"; - - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT MONTH(sys_imp_date), sum(pv) FROM 内容库产品 WHERE (sys_imp_date >= '2023-03-03' " - + "AND sys_imp_date <= '2023-09-03')" - + " GROUP BY MONTH(sys_imp_date) ORDER BY sum(pv) DESC LIMIT 1", - replaceSql); - - replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " - + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " - + "group by MONTH(数据日期) HAVING sum(访问次数) > 1000"; - - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT MONTH(sys_imp_date), sum(pv) FROM 内容库产品 WHERE (sys_imp_date >= '2023-03-03' AND" - + " sys_imp_date <= '2023-09-03') GROUP BY MONTH(sys_imp_date) HAVING sum(pv) > 1000", - replaceSql); - - replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) " - + "in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)"; - - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals("SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14' " - + "GROUP BY YEAR(publish_date)", replaceSql); - - replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 " - + "where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' " + "group by 发行日期"; - - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals("SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14'" - + " GROUP BY publish_date", replaceSql); - - replaceSql = SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 " - + "and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30", - fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE (publish_date >= '2022-08-11' " - + "AND publish_date <= '2023-08-11') AND play_count > 1000000 AND " - + "(sys_imp_date >= '2023-07-12' AND sys_imp_date <= '2023-08-11')", - replaceSql); - - replaceSql = SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE (publish_date >= '2023-08-08' AND publish_date <= '2023-08-09')" - + " AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' ORDER BY play_count DESC LIMIT 11", - replaceSql); - - replaceSql = SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE (publish_date >= '2023-01-01' AND publish_date <= '2023-08-09')" - + " AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' ORDER BY play_count DESC LIMIT 11", - replaceSql); - - replaceSql = SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE (publish_date >= '2023-02-09' AND publish_date <= '2023-08-09')" - + " AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' ORDER BY play_count DESC LIMIT 11", - replaceSql); - - replaceSql = SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - replaceSql = SqlRemoveHelper.removeNumberFilter(replaceSql); - Assert.assertEquals("SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-02-09' AND" - + " singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09'" - + " ORDER BY play_count DESC LIMIT 11", replaceSql); - - replaceSql = SqlReplaceHelper - .replaceFields("select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice'" - + " and 发布日期 ='11' order by 访问次数 desc limit 1", fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT department, user_id FROM 超音数 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1", - replaceSql); - - replaceSql = SqlReplaceHelper.replaceTable(replaceSql, "s2"); - - replaceSql = - SqlAddHelper.addFieldsToSelect(replaceSql, Collections.singletonList("field_a")); - - replaceSql = - SqlReplaceHelper.replaceFields( - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 ='alice' and 发布日期 ='11' group by 部门 limit 1", - fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals("SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1", - replaceSql); - - replaceSql = "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' " - + "and 数据日期 <= '2023-08-06' and 部门 = 'hr'"; - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals("SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' " - + "AND sys_imp_date <= '2023-08-06' AND department = 'hr'", replaceSql); - - replaceSql = "SELECT 歌曲名称, sum(评分) FROM CSpider WHERE(1 < 2) AND 数据日期 = '2023-10-15' " - + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')"; - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name, sum(评分) FROM CSpider WHERE (1 < 2) AND " - + "sys_imp_date = '2023-10-15' GROUP BY song_name HAVING " - + "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", - replaceSql); - - replaceSql = "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')" - + " FROM CSpider WHERE 数据日期 = '2023-10-15' " - + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')"; - replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT sum(评分) / (SELECT sum(评分) FROM CSpider WHERE sys_imp_date = '2023-10-15') " - + "FROM CSpider WHERE sys_imp_date = '2023-10-15' GROUP BY song_name HAVING " - + "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", - replaceSql); - } - @Test void testReplaceFunctionField() { Map fieldToBizName = initParams(); @@ -500,7 +319,7 @@ class SqlReplaceHelperTest { replaceSql); } - private Map initParams() { + protected Map initParams() { Map fieldToBizName = new HashMap<>(); fieldToBizName.put("部门", "department"); fieldToBizName.put("用户", "user_id");