From 3addfb9a8715a8dec873703383a62200caa20c7d Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 25 Sep 2023 16:18:58 +0800 Subject: [PATCH] (improvement)(common) support addAggregateToField and addGroupBy and convert metricFilter to Having (#140) --- .../chat/corrector/GlobalCorrector.java | 6 + .../util/jsqlparser/FiledExpression.java | 12 ++ .../jsqlparser/FiledFilterReplaceVisitor.java | 113 ++++++++++++++++++ .../jsqlparser/SqlParserSelectHelper.java | 30 ++++- .../jsqlparser/SqlParserUpdateHelper.java | 111 +++++++++++++++++ .../jsqlparser/SqlParserUpdateHelperTest.java | 82 +++++++++++++ 6 files changed, 348 insertions(+), 6 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java index 774e12aeb..ec9e819b0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java @@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import java.util.List; import java.util.Map; @@ -32,6 +33,11 @@ public class GlobalCorrector extends BaseSemanticCorrector { private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { + if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) { + + return; + } + } private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java new file mode 100644 index 000000000..e19ae3d80 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import lombok.Data; + +@Data +public class FiledExpression { + + private String operator; + + private String fieldName; + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java new file mode 100644 index 000000000..9950e8e43 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java @@ -0,0 +1,113 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import org.apache.commons.collections.CollectionUtils; + +@Slf4j +public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { + + private List waitingForAdds = new ArrayList<>(); + private Set fieldNames; + + public FiledFilterReplaceVisitor(Set fieldNames) { + this.fieldNames = fieldNames; + } + + @Override + public void visit(MinorThan expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(EqualsTo expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(MinorThanEquals expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + + @Override + public void visit(GreaterThan expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(GreaterThanEquals expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + public List getWaitingForAdds() { + return waitingForAdds; + } + + + public List parserFilter(ComparisonOperator comparisonOperator) { + List result = new ArrayList<>(); + String toString = comparisonOperator.toString(); + Expression leftExpression = comparisonOperator.getLeftExpression(); + if (!(leftExpression instanceof Function)) { + return result; + } + Function leftExpressionFunction = (Function) leftExpression; + if (leftExpressionFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) { + return result; + } + + List leftExpressions = leftExpressionFunction.getParameters().getExpressions(); + if (CollectionUtils.isEmpty(leftExpressions)) { + return result; + } + Column field = (Column) leftExpressions.get(0); + String columnName = field.getColumnName(); + if (!fieldNames.contains(columnName)) { + return null; + } + try { + ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 "); + comparisonOperator.setLeftExpression(expression.getLeftExpression()); + comparisonOperator.setRightExpression(expression.getRightExpression()); + comparisonOperator.setASTNode(expression.getASTNode()); + result.add(CCJSqlParserUtil.parseCondExpression(toString)); + return result; + } catch (JSQLParserException e) { + log.error("JSQLParserException", e); + } + return null; + } + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 59d40f166..a1e4002be 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -205,6 +205,30 @@ public class SqlParserSelectHelper { public static boolean hasAggregateFunction(String sql) { + if (hasFunction(sql)) { + return true; + } + return hasGroupBy(sql); + } + + public static boolean hasGroupBy(String sql) { + Select selectStatement = getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return false; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + GroupByElement groupBy = plainSelect.getGroupBy(); + if (Objects.nonNull(groupBy)) { + GroupByVisitor replaceVisitor = new GroupByVisitor(); + groupBy.accept(replaceVisitor); + return replaceVisitor.isHasAggregateFunction(); + } + return false; + } + + public static boolean hasFunction(String sql) { Select selectStatement = getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); @@ -221,12 +245,6 @@ public class SqlParserSelectHelper { if (selectFunction) { return true; } - GroupByElement groupBy = plainSelect.getGroupBy(); - if (Objects.nonNull(groupBy)) { - GroupByVisitor replaceVisitor = new GroupByVisitor(); - groupBy.accept(replaceVisitor); - return replaceVisitor.isHasAggregateFunction(); - } return false; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index b841613b1..48906a2d0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -11,6 +11,7 @@ import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.select.GroupByElement; @@ -20,6 +21,7 @@ import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.util.SelectUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -278,5 +280,114 @@ public class SqlParserUpdateHelper { return selectStatement.toString(); } + public static String addAggregateToField(String sql, Map fieldNameToAggregate) { + if (SqlParserSelectHelper.hasGroupBy(sql)) { + return sql; + } + + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + selectBody.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate); + addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate); + } + }); + return selectStatement.toString(); + } + + public static String addGroupBy(String sql, List groupByFields) { + if (SqlParserSelectHelper.hasGroupBy(sql)) { + return sql; + } + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + + PlainSelect plainSelect = (PlainSelect) selectBody; + GroupByElement groupByElement = new GroupByElement(); + for (String groupByField : groupByFields) { + groupByElement.addGroupByExpression(new Column(groupByField)); + } + plainSelect.setGroupByElement(groupByElement); + return selectStatement.toString(); + } + + private static void addAggregateToSelectItems(List selectItems, + Map fieldNameToAggregate) { + for (SelectItem selectItem : selectItems) { + if (selectItem instanceof SelectExpressionItem) { + SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; + Expression expression = selectExpressionItem.getExpression(); + String columnName = ((Column) expression).getColumnName(); + Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); + if (Objects.isNull(function)) { + continue; + } + selectExpressionItem.setExpression(function); + } + } + } + + private static void addAggregateToOrderByItems(List orderByElements, + Map fieldNameToAggregate) { + if (orderByElements == null) { + return; + } + for (OrderByElement orderByElement : orderByElements) { + Expression expression = orderByElement.getExpression(); + String columnName = ((Column) expression).getColumnName(); + if (StringUtils.isEmpty(columnName)) { + continue; + } + Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); + if (Objects.isNull(function)) { + continue; + } + orderByElement.setExpression(function); + } + } + + private static Function getFunction(Expression expression, String aggregateName) { + if (StringUtils.isEmpty(aggregateName)) { + return null; + } + Function sumFunction = new Function(); + sumFunction.setName(aggregateName); + sumFunction.setParameters(new ExpressionList(expression)); + return sumFunction; + } + + public static String addHaving(String sql, Set fieldNames) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + + PlainSelect plainSelect = (PlainSelect) selectBody; + //replace metric to 1 and 1 and add having metric + Expression where = plainSelect.getWhere(); + FiledFilterReplaceVisitor visitor = new FiledFilterReplaceVisitor(fieldNames); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + List waitingForAdds = visitor.getWaitingForAdds(); + if (!CollectionUtils.isEmpty(waitingForAdds)) { + for (Expression waitingForAdd : waitingForAdds) { + plainSelect.setHaving(waitingForAdd); + } + } + return selectStatement.toString(); + } } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index 77c0ab4c2..8745235e8 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -1,9 +1,12 @@ package com.tencent.supersonic.common.util.jsqlparser; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.parser.CCJSqlParserUtil; @@ -266,6 +269,85 @@ class SqlParserUpdateHelperTest { } + @Test + void addAggregateToField() { + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + } + + + @Test + void addAggregateToMetricField() { + String sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' order by pv desc limit 10"; + + Map filedNameToAggregate = new HashMap<>(); + filedNameToAggregate.put("pv", "sum"); + + List groupByFields = new ArrayList<>(); + groupByFields.add("department"); + + String replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + @Test + void addGroupBy() { + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " + + "order by sum(pv) desc limit 10"; + + List groupByFields = new ArrayList<>(); + groupByFields.add("department"); + + String replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + @Test + void addHaving() { + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and " + + "sum(pv) > 2000 group by department order by sum(pv) desc limit 10"; + List groupByFields = new ArrayList<>(); + groupByFields.add("department"); + + Set fieldNames = new HashSet<>(); + fieldNames.add("pv"); + + String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND 1 > 1 GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + private Map initParams() { Map fieldToBizName = new HashMap<>(); fieldToBizName.put("部门", "department");