From 500652da368801d7ca60263f9bff82e467b8a91c Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Tue, 10 Oct 2023 16:13:48 +0800 Subject: [PATCH] (improvement)(chat) support add parenthesis and add parenthesis in sys_imp_date (#184) --- .../chat/corrector/WhereCorrector.java | 1 + .../jsqlparser/SqlParserUpdateHelper.java | 48 +++---- .../jsqlparser/SqlParserSelectHelperTest.java | 51 ++++--- .../jsqlparser/SqlParserUpdateHelperTest.java | 124 +++++++++++++++++- 4 files changed, 172 insertions(+), 52 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index 06dc51a5f..0a82eec2c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -72,6 +72,7 @@ public class WhereCorrector extends BaseSemanticCorrector { List whereFields = SqlParserSelectHelper.getWhereFields(sql); if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) { String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); + sql = SqlParserUpdateHelper.addParenthesisToWhere(sql); sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate); } semanticCorrectInfo.setSql(sql); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index fd3db6fe9..0191d9b6c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -9,6 +9,7 @@ import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; 
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.OrExpression; @@ -387,17 +388,10 @@ public class SqlParserUpdateHelper { AndExpression andExpression = (AndExpression) whereExpression; Expression leftExpression = andExpression.getLeftExpression(); Expression rightExpression = andExpression.getRightExpression(); - if (isLogicExpression(leftExpression)) { - modifyWhereExpression(leftExpression, fieldNameToAggregate); - } else { - setAggToFunction(leftExpression, fieldNameToAggregate); - } - if (isLogicExpression(rightExpression)) { - modifyWhereExpression(rightExpression, fieldNameToAggregate); - } else { - setAggToFunction(rightExpression, fieldNameToAggregate); - } - setAggToFunction(rightExpression, fieldNameToAggregate); + modifyWhereExpression(leftExpression, fieldNameToAggregate); + modifyWhereExpression(rightExpression, fieldNameToAggregate); + } else if (whereExpression instanceof Parenthesis) { + modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate); } else { setAggToFunction(whereExpression, fieldNameToAggregate); } @@ -515,17 +509,11 @@ public class SqlParserUpdateHelper { AndExpression andExpression = (AndExpression) whereExpression; Expression leftExpression = andExpression.getLeftExpression(); Expression rightExpression = andExpression.getRightExpression(); - if (isLogicExpression(leftExpression)) { - removeWhereExpression(leftExpression, removeFieldNames); - } else { - removeExpressionWithConstant(leftExpression, removeFieldNames); - } - if (isLogicExpression(rightExpression)) { - removeWhereExpression(rightExpression, removeFieldNames); - } else { - removeExpressionWithConstant(rightExpression, removeFieldNames); - } - 
removeExpressionWithConstant(rightExpression, removeFieldNames); + + removeWhereExpression(leftExpression, removeFieldNames); + removeWhereExpression(rightExpression, removeFieldNames); + } else if (whereExpression instanceof Parenthesis) { + removeWhereExpression(((Parenthesis) whereExpression).getExpression(), removeFieldNames); } else { removeExpressionWithConstant(whereExpression, removeFieldNames); } @@ -577,5 +565,21 @@ public class SqlParserUpdateHelper { } return columnName; } + + public static String addParenthesisToWhere(String sql) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + Expression where = plainSelect.getWhere(); + if (Objects.nonNull(where)) { + Parenthesis parenthesis = new Parenthesis(where); + plainSelect.setWhere(parenthesis); + } + return selectStatement.toString(); + } } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index e17430eda..5fd55974c 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -5,7 +5,6 @@ import java.util.List; import java.util.Map; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -82,6 +81,13 @@ class SqlParserSelectHelperTest { + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); System.out.println(filterExpression); + + filterExpression = SqlParserSelectHelper.getFilterExpression( + "SELECT department, pv FROM s2 WHERE " + 
+ "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' " + + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(filterExpression); } @@ -118,6 +124,12 @@ class SqlParserSelectHelperTest { Assert.assertEquals(allFields.size(), 3); + allFields = SqlParserSelectHelper.getAllFields( + "SELECT department, user_id, field_a FROM s2 WHERE " + + "(user_id = 'alice' AND publish_date = '11') and sys_imp_date " + + "= '2023-08-08' ORDER BY pv DESC LIMIT 1"); + + Assert.assertEquals(allFields.size(), 6); } @@ -150,6 +162,15 @@ class SqlParserSelectHelperTest { Assert.assertEquals(selectFields.contains("发布日期"), true); Assert.assertEquals(selectFields.contains("数据日期"), true); Assert.assertEquals(selectFields.contains("用户"), true); + + sql = "select 部门,用户 from 超音数 where" + + " (用户 = 'alice' and 发布日期 ='11') and 数据日期 = '2023-08-08' " + + "order by 访问次数 limit 1"; + selectFields = SqlParserSelectHelper.getWhereFields(sql); + + Assert.assertEquals(selectFields.contains("发布日期"), true); + Assert.assertEquals(selectFields.contains("数据日期"), true); + Assert.assertEquals(selectFields.contains("用户"), true); } @Test @@ -168,34 +189,6 @@ class SqlParserSelectHelperTest { Assert.assertEquals(selectFields.contains("pv"), true); } - - @Test - void addWhere() throws JSQLParserException { - - String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; - sql = SqlParserUpdateHelper.addWhere(sql, "column_a", 123444555); - List selectFields = SqlParserSelectHelper.getAllFields(sql); - - Assert.assertEquals(selectFields.contains("column_a"), true); - - sql = SqlParserUpdateHelper.addWhere(sql, "column_b", "123456666"); - selectFields = SqlParserSelectHelper.getAllFields(sql); - - Assert.assertEquals(selectFields.contains("column_b"), true); - - Expression expression = CCJSqlParserUtil.parseCondExpression(" ( column_c = 111 or column_d = 1111)"); - - sql = 
SqlParserUpdateHelper.addWhere( - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1", - expression); - - Assert.assertEquals(sql.contains("column_c = 111"), true); - - } - - @Test void hasAggregateFunction() throws JSQLParserException { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index 355d21157..ee5bd7e74 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -57,6 +57,17 @@ class SqlParserUpdateHelperTest { + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND " + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " + + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 歌曲发布时 = '2023-08-01') " + + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + } @@ -105,6 +116,17 @@ class SqlParserUpdateHelperTest { + "AND 歌手名 = '林俊杰' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '林俊杰' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09'" + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); + + Assert.assertEquals( + 
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND " + + "歌手名 = '林俊杰' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + } @Test @@ -294,8 +316,13 @@ class SqlParserUpdateHelperTest { Assert.assertEquals(sql.contains("column_c = 111"), true); - } + sql = "select 部门,sum (访问次数) from 超音数 where 用户 = alice or 发布日期 ='2023-07-03' group by 部门 limit 1"; + sql = SqlParserUpdateHelper.addParenthesisToWhere(sql); + sql = SqlParserUpdateHelper.addWhere(sql, "数据日期", "2023-08-08"); + Assert.assertEquals(sql, "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "(用户 = alice OR 发布日期 = '2023-07-03') AND 数据日期 = '2023-08-08' GROUP BY 部门 LIMIT 1"); + } @Test void replaceFunctionName() { @@ -319,6 +346,16 @@ class SqlParserUpdateHelperTest { "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" + " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)", replaceSql); + + sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + + " (datediff('month', 数据日期, '2023-09-02') <= 6) and 数据日期 = '2023-10-10' group by MONTH(数据日期)"; + replaceSql = SqlParserUpdateHelper.replaceFunction(sql, functionMap); + + Assert.assertEquals( + "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" + + " (datediff('month', 数据日期, '2023-09-02') <= 6) AND " + + "数据日期 = '2023-10-10' GROUP BY toMonth(数据日期)", + replaceSql); } @Test @@ -332,6 +369,17 @@ class SqlParserUpdateHelperTest { + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", replaceSql); + sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 总访问次数 desc limit 10"; + replaceSql = SqlParserUpdateHelper.replaceAlias(sql); + System.out.println(replaceSql); + Assert.assertEquals( + "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " + + "GROUP BY 部门 
ORDER BY sum(访问次数) DESC LIMIT 10", + replaceSql); + } @Test @@ -356,6 +404,16 @@ class SqlParserUpdateHelperTest { + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", replaceSql); + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " + + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + } @Test @@ -380,6 +438,15 @@ class SqlParserUpdateHelperTest { + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", replaceSql); + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " + + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); } @@ -458,6 +525,16 @@ class SqlParserUpdateHelperTest { "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 " + "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); + + sql = "select department, pv from t_1 where (pv >1 and department = 'HR') " + + " and sys_imp_date = '2023-09-11' GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, 
groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE (sum(pv) > 1 AND department = 'HR') AND " + + "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); } @Test @@ -474,6 +551,16 @@ class SqlParserUpdateHelperTest { "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); + + sql = "select department, sum(pv) from t_1 where (department = 'HR') and sys_imp_date = '2023-09-11' " + + "order by sum(pv) desc limit 10"; + + replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE (department = 'HR') AND sys_imp_date " + + "= '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); } @Test @@ -492,6 +579,16 @@ class SqlParserUpdateHelperTest { "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 " + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", replaceSql); + + sql = "select department, sum(pv) from t_1 where (sum(pv) > 2000) and sys_imp_date = '2023-09-11' " + + "group by department order by sum(pv) desc limit 10"; + + replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE (2 > 1) AND sys_imp_date = '2023-09-11' " + + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); } @Test @@ -520,6 +617,31 @@ class SqlParserUpdateHelperTest { + "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND " + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + sql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋')) and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; + replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames); + 
Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 1 IN (1) AND 1 IN (1)) AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", + replaceSql); + } + + + @Test + void addParenthesisToWhere() { + String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + String replaceSql = SqlParserUpdateHelper.addParenthesisToWhere(sql); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01') " + + "ORDER BY 播放量 DESC LIMIT 11", + replaceSql); } private Map initParams() {