mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) support add parenthesis and add arenthesis in sys_imp_date (#184)
This commit is contained in:
@@ -72,6 +72,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
|
||||
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
|
||||
sql = SqlParserUpdateHelper.addParenthesisToWhere(sql);
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
|
||||
}
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
|
||||
@@ -9,6 +9,7 @@ import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.Parenthesis;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
@@ -387,17 +388,10 @@ public class SqlParserUpdateHelper {
|
||||
AndExpression andExpression = (AndExpression) whereExpression;
|
||||
Expression leftExpression = andExpression.getLeftExpression();
|
||||
Expression rightExpression = andExpression.getRightExpression();
|
||||
if (isLogicExpression(leftExpression)) {
|
||||
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
||||
} else {
|
||||
setAggToFunction(leftExpression, fieldNameToAggregate);
|
||||
}
|
||||
if (isLogicExpression(rightExpression)) {
|
||||
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
||||
} else {
|
||||
setAggToFunction(rightExpression, fieldNameToAggregate);
|
||||
}
|
||||
setAggToFunction(rightExpression, fieldNameToAggregate);
|
||||
} else if (whereExpression instanceof Parenthesis) {
|
||||
modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate);
|
||||
} else {
|
||||
setAggToFunction(whereExpression, fieldNameToAggregate);
|
||||
}
|
||||
@@ -515,17 +509,11 @@ public class SqlParserUpdateHelper {
|
||||
AndExpression andExpression = (AndExpression) whereExpression;
|
||||
Expression leftExpression = andExpression.getLeftExpression();
|
||||
Expression rightExpression = andExpression.getRightExpression();
|
||||
if (isLogicExpression(leftExpression)) {
|
||||
|
||||
removeWhereExpression(leftExpression, removeFieldNames);
|
||||
} else {
|
||||
removeExpressionWithConstant(leftExpression, removeFieldNames);
|
||||
}
|
||||
if (isLogicExpression(rightExpression)) {
|
||||
removeWhereExpression(rightExpression, removeFieldNames);
|
||||
} else {
|
||||
removeExpressionWithConstant(rightExpression, removeFieldNames);
|
||||
}
|
||||
removeExpressionWithConstant(rightExpression, removeFieldNames);
|
||||
} else if (whereExpression instanceof Parenthesis) {
|
||||
removeWhereExpression(((Parenthesis) whereExpression).getExpression(), removeFieldNames);
|
||||
} else {
|
||||
removeExpressionWithConstant(whereExpression, removeFieldNames);
|
||||
}
|
||||
@@ -577,5 +565,21 @@ public class SqlParserUpdateHelper {
|
||||
}
|
||||
return columnName;
|
||||
}
|
||||
|
||||
public static String addParenthesisToWhere(String sql) {
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
Expression where = plainSelect.getWhere();
|
||||
if (Objects.nonNull(where)) {
|
||||
Parenthesis parenthesis = new Parenthesis(where);
|
||||
plainSelect.setWhere(parenthesis);
|
||||
}
|
||||
return selectStatement.toString();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -82,6 +81,13 @@ class SqlParserSelectHelperTest {
|
||||
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
||||
|
||||
System.out.println(filterExpression);
|
||||
|
||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
||||
"SELECT department, pv FROM s2 WHERE "
|
||||
+ "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' "
|
||||
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
||||
|
||||
System.out.println(filterExpression);
|
||||
}
|
||||
|
||||
|
||||
@@ -118,6 +124,12 @@ class SqlParserSelectHelperTest {
|
||||
|
||||
Assert.assertEquals(allFields.size(), 3);
|
||||
|
||||
allFields = SqlParserSelectHelper.getAllFields(
|
||||
"SELECT department, user_id, field_a FROM s2 WHERE "
|
||||
+ "(user_id = 'alice' AND publish_date = '11') and sys_imp_date "
|
||||
+ "= '2023-08-08' ORDER BY pv DESC LIMIT 1");
|
||||
|
||||
Assert.assertEquals(allFields.size(), 6);
|
||||
}
|
||||
|
||||
|
||||
@@ -150,6 +162,15 @@ class SqlParserSelectHelperTest {
|
||||
Assert.assertEquals(selectFields.contains("发布日期"), true);
|
||||
Assert.assertEquals(selectFields.contains("数据日期"), true);
|
||||
Assert.assertEquals(selectFields.contains("用户"), true);
|
||||
|
||||
sql = "select 部门,用户 from 超音数 where"
|
||||
+ " (用户 = 'alice' and 发布日期 ='11') and 数据日期 = '2023-08-08' "
|
||||
+ "order by 访问次数 limit 1";
|
||||
selectFields = SqlParserSelectHelper.getWhereFields(sql);
|
||||
|
||||
Assert.assertEquals(selectFields.contains("发布日期"), true);
|
||||
Assert.assertEquals(selectFields.contains("数据日期"), true);
|
||||
Assert.assertEquals(selectFields.contains("用户"), true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -168,34 +189,6 @@ class SqlParserSelectHelperTest {
|
||||
Assert.assertEquals(selectFields.contains("pv"), true);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void addWhere() throws JSQLParserException {
|
||||
|
||||
String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
|
||||
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1";
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, "column_a", 123444555);
|
||||
List<String> selectFields = SqlParserSelectHelper.getAllFields(sql);
|
||||
|
||||
Assert.assertEquals(selectFields.contains("column_a"), true);
|
||||
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, "column_b", "123456666");
|
||||
selectFields = SqlParserSelectHelper.getAllFields(sql);
|
||||
|
||||
Assert.assertEquals(selectFields.contains("column_b"), true);
|
||||
|
||||
Expression expression = CCJSqlParserUtil.parseCondExpression(" ( column_c = 111 or column_d = 1111)");
|
||||
|
||||
sql = SqlParserUpdateHelper.addWhere(
|
||||
"select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
|
||||
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1",
|
||||
expression);
|
||||
|
||||
Assert.assertEquals(sql.contains("column_c = 111"), true);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void hasAggregateFunction() throws JSQLParserException {
|
||||
|
||||
|
||||
@@ -57,6 +57,17 @@ class SqlParserUpdateHelperTest {
|
||||
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND "
|
||||
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
|
||||
|
||||
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09' "
|
||||
+ " order by 播放量 desc limit 11";
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
|
||||
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 歌曲发布时 = '2023-08-01') "
|
||||
+ "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -105,6 +116,17 @@ class SqlParserUpdateHelperTest {
|
||||
+ "AND 歌手名 = '林俊杰' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = "
|
||||
+ "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
|
||||
|
||||
replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "and 歌手名 = '林俊杰' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09'"
|
||||
+ " order by 播放量 desc limit 11";
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND "
|
||||
+ "歌手名 = '林俊杰' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 11", replaceSql);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -294,8 +316,13 @@ class SqlParserUpdateHelperTest {
|
||||
|
||||
Assert.assertEquals(sql.contains("column_c = 111"), true);
|
||||
|
||||
}
|
||||
sql = "select 部门,sum (访问次数) from 超音数 where 用户 = alice or 发布日期 ='2023-07-03' group by 部门 limit 1";
|
||||
sql = SqlParserUpdateHelper.addParenthesisToWhere(sql);
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, "数据日期", "2023-08-08");
|
||||
Assert.assertEquals(sql, "SELECT 部门, sum(访问次数) FROM 超音数 WHERE "
|
||||
+ "(用户 = alice OR 发布日期 = '2023-07-03') AND 数据日期 = '2023-08-08' GROUP BY 部门 LIMIT 1");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void replaceFunctionName() {
|
||||
@@ -319,6 +346,16 @@ class SqlParserUpdateHelperTest {
|
||||
"SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE"
|
||||
+ " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)",
|
||||
replaceSql);
|
||||
|
||||
sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where"
|
||||
+ " (datediff('month', 数据日期, '2023-09-02') <= 6) and 数据日期 = '2023-10-10' group by MONTH(数据日期)";
|
||||
replaceSql = SqlParserUpdateHelper.replaceFunction(sql, functionMap);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE"
|
||||
+ " (datediff('month', 数据日期, '2023-09-02') <= 6) AND "
|
||||
+ "数据日期 = '2023-10-10' GROUP BY toMonth(数据日期)",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -332,6 +369,17 @@ class SqlParserUpdateHelperTest {
|
||||
+ "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
|
||||
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
|
||||
+ "group by 部门 order by 总访问次数 desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.replaceAlias(sql);
|
||||
System.out.println(replaceSql);
|
||||
Assert.assertEquals(
|
||||
"SELECT 部门, sum(访问次数) FROM 超音数 WHERE "
|
||||
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' "
|
||||
+ "GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -356,6 +404,16 @@ class SqlParserUpdateHelperTest {
|
||||
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
|
||||
replaceSql);
|
||||
|
||||
sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND "
|
||||
+ "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
|
||||
havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
System.out.println(replaceSql);
|
||||
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') "
|
||||
+ "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
|
||||
replaceSql);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -380,6 +438,15 @@ class SqlParserUpdateHelperTest {
|
||||
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
|
||||
replaceSql);
|
||||
|
||||
sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND "
|
||||
+ "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
|
||||
havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
System.out.println(replaceSql);
|
||||
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') "
|
||||
+ "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
|
||||
@@ -458,6 +525,16 @@ class SqlParserUpdateHelperTest {
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 "
|
||||
+ "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select department, pv from t_1 where (pv >1 and department = 'HR') "
|
||||
+ " and sys_imp_date = '2023-09-11' GROUP BY department order by pv desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE (sum(pv) > 1 AND department = 'HR') AND "
|
||||
+ "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -474,6 +551,16 @@ class SqlParserUpdateHelperTest {
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select department, sum(pv) from t_1 where (department = 'HR') and sys_imp_date = '2023-09-11' "
|
||||
+ "order by sum(pv) desc limit 10";
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE (department = 'HR') AND sys_imp_date "
|
||||
+ "= '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -492,6 +579,16 @@ class SqlParserUpdateHelperTest {
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 "
|
||||
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select department, sum(pv) from t_1 where (sum(pv) > 2000) and sys_imp_date = '2023-09-11' "
|
||||
+ "group by department order by sum(pv) desc limit 10";
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE (2 > 1) AND sys_imp_date = '2023-09-11' "
|
||||
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -520,6 +617,31 @@ class SqlParserUpdateHelperTest {
|
||||
+ "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND "
|
||||
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
|
||||
sql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋')) and 数据日期 = '2023-08-09' "
|
||||
+ " order by 播放量 desc limit 11";
|
||||
replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames);
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "AND 1 IN (1) AND 1 IN (1)) AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void addParenthesisToWhere() {
|
||||
String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
|
||||
+ " order by 播放量 desc limit 11";
|
||||
|
||||
String replaceSql = SqlParserUpdateHelper.addParenthesisToWhere(sql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "AND 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01') "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
private Map<String, String> initParams() {
|
||||
|
||||
Reference in New Issue
Block a user