(improvement)(chat) add default aggregate to all metric and add group by to dimension and add metric filter in having (#150)

This commit is contained in:
lexluo09
2023-09-27 00:05:45 +08:00
committed by GitHub
parent ff5479f1a2
commit 24e8e756de
18 changed files with 327 additions and 85 deletions

View File

@@ -31,7 +31,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(MinorThan expr) {
List<Expression> expressions = parserFilter(expr);
List<Expression> expressions = parserFilter(expr, " 1 < 2 ");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -39,7 +39,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(EqualsTo expr) {
List<Expression> expressions = parserFilter(expr);
List<Expression> expressions = parserFilter(expr, " 1 = 1 ");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -47,7 +47,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(MinorThanEquals expr) {
List<Expression> expressions = parserFilter(expr);
List<Expression> expressions = parserFilter(expr, " 1 <= 1 ");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -56,7 +56,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(GreaterThan expr) {
List<Expression> expressions = parserFilter(expr);
List<Expression> expressions = parserFilter(expr, " 2 > 1 ");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -64,7 +64,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(GreaterThanEquals expr) {
List<Expression> expressions = parserFilter(expr);
List<Expression> expressions = parserFilter(expr, " 1 >= 1 ");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -75,7 +75,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
}
public List<Expression> parserFilter(ComparisonOperator comparisonOperator) {
public List<Expression> parserFilter(ComparisonOperator comparisonOperator, String condExpr) {
List<Expression> result = new ArrayList<>();
String toString = comparisonOperator.toString();
Expression leftExpression = comparisonOperator.getLeftExpression();
@@ -97,7 +97,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
return null;
}
try {
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 ");
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
comparisonOperator.setLeftExpression(expression.getLeftExpression());
comparisonOperator.setRightExpression(expression.getRightExpression());
comparisonOperator.setASTNode(expression.getASTNode());

View File

@@ -10,6 +10,9 @@ import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.conditional.XorExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.schema.Column;
@@ -281,10 +284,6 @@ public class SqlParserUpdateHelper {
}
public static String addAggregateToField(String sql, Map<String, String> fieldNameToAggregate) {
if (SqlParserSelectHelper.hasGroupBy(sql)) {
return sql;
}
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
@@ -296,13 +295,15 @@ public class SqlParserUpdateHelper {
public void visit(PlainSelect plainSelect) {
addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate);
addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate);
addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate);
addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate);
}
});
return selectStatement.toString();
}
public static String addGroupBy(String sql, List<String> groupByFields) {
if (SqlParserSelectHelper.hasGroupBy(sql)) {
public static String addGroupBy(String sql, Set<String> groupByFields) {
if (SqlParserSelectHelper.hasGroupBy(sql) || CollectionUtils.isEmpty(groupByFields)) {
return sql;
}
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
@@ -327,9 +328,8 @@ public class SqlParserUpdateHelper {
if (selectItem instanceof SelectExpressionItem) {
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
Expression expression = selectExpressionItem.getExpression();
String columnName = ((Column) expression).getColumnName();
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
if (Objects.isNull(function)) {
Function function = getFunction(expression, fieldNameToAggregate);
if (function == null) {
continue;
}
selectExpressionItem.setExpression(function);
@@ -344,18 +344,102 @@ public class SqlParserUpdateHelper {
}
for (OrderByElement orderByElement : orderByElements) {
Expression expression = orderByElement.getExpression();
String columnName = ((Column) expression).getColumnName();
if (StringUtils.isEmpty(columnName)) {
continue;
}
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
if (Objects.isNull(function)) {
Function function = getFunction(expression, fieldNameToAggregate);
if (function == null) {
continue;
}
orderByElement.setExpression(function);
}
}
private static void addAggregateToGroupByItems(GroupByElement groupByElement,
Map<String, String> fieldNameToAggregate) {
if (groupByElement == null) {
return;
}
for (Expression expression : groupByElement.getGroupByExpressions()) {
Function function = getFunction(expression, fieldNameToAggregate);
if (function == null) {
continue;
}
groupByElement.addGroupByExpression(function);
}
}
private static void addAggregateToWhereItems(Expression whereExpression, Map<String, String> fieldNameToAggregate) {
if (whereExpression == null) {
return;
}
modifyWhereExpression(whereExpression, fieldNameToAggregate);
}
private static void modifyWhereExpression(Expression whereExpression,
Map<String, String> fieldNameToAggregate) {
if (isLogicExpression(whereExpression)) {
AndExpression andExpression = (AndExpression) whereExpression;
Expression leftExpression = andExpression.getLeftExpression();
Expression rightExpression = andExpression.getRightExpression();
if (isLogicExpression(leftExpression)) {
modifyWhereExpression(leftExpression, fieldNameToAggregate);
} else {
setAggToFunction(leftExpression, fieldNameToAggregate);
}
if (isLogicExpression(rightExpression)) {
modifyWhereExpression(rightExpression, fieldNameToAggregate);
} else {
setAggToFunction(rightExpression, fieldNameToAggregate);
}
setAggToFunction(rightExpression, fieldNameToAggregate);
} else {
setAggToFunction(whereExpression, fieldNameToAggregate);
}
}
private static boolean isLogicExpression(Expression whereExpression) {
return whereExpression instanceof AndExpression || (whereExpression instanceof OrExpression
|| (whereExpression instanceof XorExpression));
}
private static void setAggToFunction(Expression expression, Map<String, String> fieldNameToAggregate) {
if (!(expression instanceof ComparisonOperator)) {
return;
}
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
if (comparisonOperator.getRightExpression() instanceof Column) {
String columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName();
Function function = getFunction(comparisonOperator.getRightExpression(),
fieldNameToAggregate.get(columnName));
if (Objects.nonNull(function)) {
comparisonOperator.setRightExpression(function);
}
}
if (comparisonOperator.getLeftExpression() instanceof Column) {
String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
Function function = getFunction(comparisonOperator.getLeftExpression(),
fieldNameToAggregate.get(columnName));
if (Objects.nonNull(function)) {
comparisonOperator.setLeftExpression(function);
}
}
}
private static Function getFunction(Expression expression, Map<String, String> fieldNameToAggregate) {
if (!(expression instanceof Column)) {
return null;
}
String columnName = ((Column) expression).getColumnName();
if (StringUtils.isEmpty(columnName)) {
return null;
}
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
if (Objects.isNull(function)) {
return null;
}
return function;
}
private static Function getFunction(Expression expression, String aggregateName) {
if (StringUtils.isEmpty(aggregateName)) {
return null;

View File

@@ -301,7 +301,7 @@ class SqlParserUpdateHelperTest {
Map<String, String> filedNameToAggregate = new HashMap<>();
filedNameToAggregate.put("pv", "sum");
List<String> groupByFields = new ArrayList<>();
Set<String> groupByFields = new HashSet<>();
groupByFields.add("department");
String replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
@@ -311,6 +311,66 @@ class SqlParserUpdateHelperTest {
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 "
+ "order by pv desc limit 10";
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
+ "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, pv from t_1 where pv >1 order by pv desc limit 10";
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 "
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, pv from t_1 where sum(pv) >1 order by pv desc limit 10";
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 "
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 "
+ "GROUP BY department order by pv desc limit 10";
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
+ "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 "
+ "GROUP BY department order by pv desc limit 10";
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
+ "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 and department = 'HR' "
+ "GROUP BY department order by pv desc limit 10";
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 "
+ "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
}
@Test
@@ -318,7 +378,7 @@ class SqlParserUpdateHelperTest {
String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' "
+ "order by sum(pv) desc limit 10";
List<String> groupByFields = new ArrayList<>();
Set<String> groupByFields = new HashSet<>();
groupByFields.add("department");
String replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields);
@@ -342,8 +402,8 @@ class SqlParserUpdateHelperTest {
String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames);
Assert.assertEquals(
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
+ "AND 1 > 1 GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 "
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
}