(improvement)(chat) support add parenthesis and add arenthesis in sys_imp_date (#184)

This commit is contained in:
lexluo09
2023-10-10 16:13:48 +08:00
committed by GitHub
parent eee39f56a8
commit 500652da36
4 changed files with 172 additions and 52 deletions

View File

@@ -72,6 +72,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addParenthesisToWhere(sql);
sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
}
semanticCorrectInfo.setSql(sql);

View File

@@ -9,6 +9,7 @@ import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
@@ -387,17 +388,10 @@ public class SqlParserUpdateHelper {
AndExpression andExpression = (AndExpression) whereExpression;
Expression leftExpression = andExpression.getLeftExpression();
Expression rightExpression = andExpression.getRightExpression();
if (isLogicExpression(leftExpression)) {
modifyWhereExpression(leftExpression, fieldNameToAggregate);
} else {
setAggToFunction(leftExpression, fieldNameToAggregate);
}
if (isLogicExpression(rightExpression)) {
modifyWhereExpression(rightExpression, fieldNameToAggregate);
} else {
setAggToFunction(rightExpression, fieldNameToAggregate);
}
setAggToFunction(rightExpression, fieldNameToAggregate);
} else if (whereExpression instanceof Parenthesis) {
modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate);
} else {
setAggToFunction(whereExpression, fieldNameToAggregate);
}
@@ -515,17 +509,11 @@ public class SqlParserUpdateHelper {
AndExpression andExpression = (AndExpression) whereExpression;
Expression leftExpression = andExpression.getLeftExpression();
Expression rightExpression = andExpression.getRightExpression();
if (isLogicExpression(leftExpression)) {
removeWhereExpression(leftExpression, removeFieldNames);
} else {
removeExpressionWithConstant(leftExpression, removeFieldNames);
}
if (isLogicExpression(rightExpression)) {
removeWhereExpression(rightExpression, removeFieldNames);
} else {
removeExpressionWithConstant(rightExpression, removeFieldNames);
}
removeExpressionWithConstant(rightExpression, removeFieldNames);
} else if (whereExpression instanceof Parenthesis) {
removeWhereExpression(((Parenthesis) whereExpression).getExpression(), removeFieldNames);
} else {
removeExpressionWithConstant(whereExpression, removeFieldNames);
}
@@ -577,5 +565,21 @@ public class SqlParserUpdateHelper {
}
return columnName;
}
public static String addParenthesisToWhere(String sql) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
PlainSelect plainSelect = (PlainSelect) selectBody;
Expression where = plainSelect.getWhere();
if (Objects.nonNull(where)) {
Parenthesis parenthesis = new Parenthesis(where);
plainSelect.setWhere(parenthesis);
}
return selectStatement.toString();
}
}

View File

@@ -5,7 +5,6 @@ import java.util.List;
import java.util.Map;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
@@ -82,6 +81,13 @@ class SqlParserSelectHelperTest {
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
System.out.println(filterExpression);
filterExpression = SqlParserSelectHelper.getFilterExpression(
"SELECT department, pv FROM s2 WHERE "
+ "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' "
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
System.out.println(filterExpression);
}
@@ -118,6 +124,12 @@ class SqlParserSelectHelperTest {
Assert.assertEquals(allFields.size(), 3);
allFields = SqlParserSelectHelper.getAllFields(
"SELECT department, user_id, field_a FROM s2 WHERE "
+ "(user_id = 'alice' AND publish_date = '11') and sys_imp_date "
+ "= '2023-08-08' ORDER BY pv DESC LIMIT 1");
Assert.assertEquals(allFields.size(), 6);
}
@@ -150,6 +162,15 @@ class SqlParserSelectHelperTest {
Assert.assertEquals(selectFields.contains("发布日期"), true);
Assert.assertEquals(selectFields.contains("数据日期"), true);
Assert.assertEquals(selectFields.contains("用户"), true);
sql = "select 部门,用户 from 超音数 where"
+ " (用户 = 'alice' and 发布日期 ='11') and 数据日期 = '2023-08-08' "
+ "order by 访问次数 limit 1";
selectFields = SqlParserSelectHelper.getWhereFields(sql);
Assert.assertEquals(selectFields.contains("发布日期"), true);
Assert.assertEquals(selectFields.contains("数据日期"), true);
Assert.assertEquals(selectFields.contains("用户"), true);
}
@Test
@@ -168,34 +189,6 @@ class SqlParserSelectHelperTest {
Assert.assertEquals(selectFields.contains("pv"), true);
}
@Test
void addWhere() throws JSQLParserException {
String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1";
sql = SqlParserUpdateHelper.addWhere(sql, "column_a", 123444555);
List<String> selectFields = SqlParserSelectHelper.getAllFields(sql);
Assert.assertEquals(selectFields.contains("column_a"), true);
sql = SqlParserUpdateHelper.addWhere(sql, "column_b", "123456666");
selectFields = SqlParserSelectHelper.getAllFields(sql);
Assert.assertEquals(selectFields.contains("column_b"), true);
Expression expression = CCJSqlParserUtil.parseCondExpression(" ( column_c = 111 or column_d = 1111)");
sql = SqlParserUpdateHelper.addWhere(
"select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1",
expression);
Assert.assertEquals(sql.contains("column_c = 111"), true);
}
@Test
void hasAggregateFunction() throws JSQLParserException {

View File

@@ -57,6 +57,17 @@ class SqlParserUpdateHelperTest {
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND "
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09' "
+ " order by 播放量 desc limit 11";
replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 歌曲发布时 = '2023-08-01') "
+ "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
}
@@ -105,6 +116,17 @@ class SqlParserUpdateHelperTest {
+ "AND 歌手名 = '林俊杰' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = "
+ "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌手名 = '林俊杰' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09'"
+ " order by 播放量 desc limit 11";
replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND "
+ "歌手名 = '林俊杰' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' "
+ "ORDER BY 播放量 DESC LIMIT 11", replaceSql);
}
@Test
@@ -294,8 +316,13 @@ class SqlParserUpdateHelperTest {
Assert.assertEquals(sql.contains("column_c = 111"), true);
}
sql = "select 部门,sum (访问次数) from 超音数 where 用户 = alice or 发布日期 ='2023-07-03' group by 部门 limit 1";
sql = SqlParserUpdateHelper.addParenthesisToWhere(sql);
sql = SqlParserUpdateHelper.addWhere(sql, "数据日期", "2023-08-08");
Assert.assertEquals(sql, "SELECT 部门, sum(访问次数) FROM 超音数 WHERE "
+ "(用户 = alice OR 发布日期 = '2023-07-03') AND 数据日期 = '2023-08-08' GROUP BY 部门 LIMIT 1");
}
@Test
void replaceFunctionName() {
@@ -319,6 +346,16 @@ class SqlParserUpdateHelperTest {
"SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE"
+ " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)",
replaceSql);
sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where"
+ " (datediff('month', 数据日期, '2023-09-02') <= 6) and 数据日期 = '2023-10-10' group by MONTH(数据日期)";
replaceSql = SqlParserUpdateHelper.replaceFunction(sql, functionMap);
Assert.assertEquals(
"SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE"
+ " (datediff('month', 数据日期, '2023-09-02') <= 6) AND "
+ "数据日期 = '2023-10-10' GROUP BY toMonth(数据日期)",
replaceSql);
}
@Test
@@ -332,6 +369,17 @@ class SqlParserUpdateHelperTest {
+ "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10",
replaceSql);
sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
+ "group by 部门 order by 总访问次数 desc limit 10";
replaceSql = SqlParserUpdateHelper.replaceAlias(sql);
System.out.println(replaceSql);
Assert.assertEquals(
"SELECT 部门, sum(访问次数) FROM 超音数 WHERE "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' "
+ "GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10",
replaceSql);
}
@Test
@@ -356,6 +404,16 @@ class SqlParserUpdateHelperTest {
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
replaceSql);
sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND "
+ "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
System.out.println(replaceSql);
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') "
+ "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
replaceSql);
}
@Test
@@ -380,6 +438,15 @@ class SqlParserUpdateHelperTest {
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
replaceSql);
sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND "
+ "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
System.out.println(replaceSql);
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') "
+ "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
replaceSql);
}
@@ -458,6 +525,16 @@ class SqlParserUpdateHelperTest {
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 "
+ "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, pv from t_1 where (pv >1 and department = 'HR') "
+ " and sys_imp_date = '2023-09-11' GROUP BY department order by pv desc limit 10";
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE (sum(pv) > 1 AND department = 'HR') AND "
+ "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
}
@Test
@@ -474,6 +551,16 @@ class SqlParserUpdateHelperTest {
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, sum(pv) from t_1 where (department = 'HR') and sys_imp_date = '2023-09-11' "
+ "order by sum(pv) desc limit 10";
replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE (department = 'HR') AND sys_imp_date "
+ "= '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
}
@Test
@@ -492,6 +579,16 @@ class SqlParserUpdateHelperTest {
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 "
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, sum(pv) from t_1 where (sum(pv) > 2000) and sys_imp_date = '2023-09-11' "
+ "group by department order by sum(pv) desc limit 10";
replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE (2 > 1) AND sys_imp_date = '2023-09-11' "
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
}
@Test
@@ -520,6 +617,31 @@ class SqlParserUpdateHelperTest {
+ "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND "
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
sql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋')) and 数据日期 = '2023-08-09' "
+ " order by 播放量 desc limit 11";
replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 1 IN (1) AND 1 IN (1)) AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
}
@Test
void addParenthesisToWhere() {
String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
String replaceSql = SqlParserUpdateHelper.addParenthesisToWhere(sql);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01') "
+ "ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
}
private Map<String, String> initParams() {