diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index e009edf2c..53eb01031 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -123,34 +123,34 @@ public class SqlParserSelectHelper { public static List getPlainSelects(PlainSelect plainSelect) { List plainSelects = new ArrayList<>(); plainSelects.add(plainSelect); + + ExpressionVisitorAdapter expressionVisitor = new ExpressionVisitorAdapter() { + @Override + public void visit(SubSelect subSelect) { + SelectBody subSelectBody = subSelect.getSelectBody(); + if (subSelectBody instanceof PlainSelect) { + plainSelects.add((PlainSelect) subSelectBody); + } + } + }; + plainSelect.accept(new SelectVisitorAdapter() { @Override public void visit(PlainSelect plainSelect) { Expression whereExpression = plainSelect.getWhere(); if (whereExpression != null) { - whereExpression.accept(new ExpressionVisitorAdapter() { - @Override - public void visit(SubSelect subSelect) { - SelectBody subSelectBody = subSelect.getSelectBody(); - if (subSelectBody instanceof PlainSelect) { - plainSelects.add((PlainSelect) subSelectBody); - } - } - }); + whereExpression.accept(expressionVisitor); } Expression having = plainSelect.getHaving(); if (Objects.nonNull(having)) { - having.accept(new ExpressionVisitorAdapter() { - @Override - public void visit(SubSelect subSelect) { - SelectBody subSelectBody = subSelect.getSelectBody(); - if (subSelectBody instanceof PlainSelect) { - plainSelects.add((PlainSelect) subSelectBody); - } - } - }); + having.accept(expressionVisitor); + } + List selectItems = plainSelect.getSelectItems(); + if (!CollectionUtils.isEmpty(selectItems)) { + for (SelectItem selectItem : selectItems) { + selectItem.accept(expressionVisitor); + } } - } }); return plainSelects; diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java index 5b71b2920..505066dde 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java @@ -295,6 +295,17 @@ class SqlParserReplaceHelperTest { "SELECT song_name, sum(user_id) FROM CSpider WHERE (1 < 2) AND " + "sys_imp_date = '2023-10-15' GROUP BY song_name HAVING " + "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql); + + replaceSql = "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')" + + " FROM CSpider WHERE 数据日期 = '2023-10-15' " + + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')"; + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT sum(user_id) / (SELECT sum(user_id) FROM CSpider WHERE sys_imp_date = '2023-10-15') " + + "FROM CSpider WHERE sys_imp_date = '2023-10-15' GROUP BY song_name HAVING " + + "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql); } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index 1ce424e69..6ebb9d555 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -108,6 +108,11 @@ class SqlParserSelectHelperTest { + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); System.out.println(filterExpression); + + filterExpression = SqlParserSelectHelper.getFilterExpression("SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 " + + "WHERE MONTH(数据日期) = 9) FROM 营销月模型 WHERE 国家中文名 = '肯尼亚' AND MONTH(数据日期) = 9"); + + System.out.println(filterExpression); } @@ -157,6 +162,10 @@ class SqlParserSelectHelperTest { Assert.assertEquals(allFields.size(), 3); + allFields = SqlParserSelectHelper.getAllFields("SELECT sum(销量) / (SELECT sum(销量) FROM 营销 " + + "WHERE MONTH(数据日期) = 9) FROM 营销 WHERE 国家中文名 = '中国' AND MONTH(数据日期) = 9"); + + Assert.assertEquals(allFields.size(), 3); }