From 2fe56e74625512fda755ada8728edc9da06cf9c9 Mon Sep 17 00:00:00 2001 From: yangde <34122685+yonyong@users.noreply.github.com> Date: Sat, 4 Nov 2023 12:58:25 +0800 Subject: [PATCH] (improvement)(chat) Special handling for count_distinct operator during SQL correcting and explaining (#320) --- .../common/pojo/enums/AggOperatorEnum.java | 11 ++ .../SqlParserSelectFunctionHelper.java | 9 +- .../jsqlparser/SqlParserAddHelperTest.java | 100 ++++++++++++++++++ .../parser/convert/QueryReqConverter.java | 3 +- 4 files changed, 121 insertions(+), 2 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java index 5931dd76a..b6c0125cb 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java @@ -41,5 +41,16 @@ public enum AggOperatorEnum { return AggOperatorEnum.UNKNOWN; } + /** + * Determines whether aggType names the count_distinct operator (case-insensitive, locale-independent). + * 1. outer SQL parses the count_distinct(field) operator as count(DISTINCT field). + * 2. tableSQL generates aggregation that ignores the count_distinct operator. 
+ * @param aggType aggregate operator name to check; may be null + * @return true if aggType matches the COUNT_DISTINCT operator ignoring case; false otherwise or for null + */ + public static boolean isCountDistinct(String aggType) { + return COUNT_DISTINCT.getOperator().equalsIgnoreCase(aggType); + } + } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java index 4f54a7470..a5db686bf 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java @@ -5,6 +5,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; + +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; @@ -74,7 +76,12 @@ public class SqlParserSelectFunctionHelper { return null; } Function sumFunction = new Function(); - sumFunction.setName(aggregateName); + if (AggOperatorEnum.isCountDistinct(aggregateName)) { + sumFunction.setName("count"); + sumFunction.setDistinct(true); + } else { + sumFunction.setName(aggregateName); + } sumFunction.setParameters(new ExpressionList(expression)); return sumFunction; } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java index 4b63a7a79..3d763dc9f 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java @@ -217,6 +217,106 @@ class SqlParserAddHelperTest { replaceSql); } + + @Test + void 
addAggregateToCountDiscountMetricField() { + String sql = "select department, uv from t_1 where sys_imp_date = '2023-09-11' order by uv desc limit 10"; + + Map filedNameToAggregate = new HashMap<>(); + filedNameToAggregate.put("uv", "count_distinct"); + + Set groupByFields = new HashSet<>(); + groupByFields.add("department"); + + String replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 " + + "order by uv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND count(DISTINCT uv) > 1 GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, uv from t_1 where uv >1 order by uv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE count(DISTINCT uv) > 1 " + + "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, uv from t_1 where count(DISTINCT uv) >1 order by uv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE count(DISTINCT uv) > 1 " + + "GROUP BY department 
ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11' and count(DISTINCT uv) >1 " + + "GROUP BY department order by count(DISTINCT uv) desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND count(DISTINCT uv) > 1 GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 " + + "GROUP BY department order by count(DISTINCT uv) desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND count(DISTINCT uv) > 1 GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 and department = 'HR' " + + "GROUP BY department order by uv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND count(DISTINCT uv) > 1 " + + "AND department = 'HR' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, uv from t_1 where (uv >1 and department = 'HR') " + + " and sys_imp_date = '2023-09-11' GROUP BY department order by uv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = 
SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) FROM t_1 WHERE (count(DISTINCT uv) > 1 AND department = 'HR') AND " + + "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + + sql = "select department, count(DISTINCT uv) as uv from t_1 where sys_imp_date = '2023-09-11' GROUP BY " + + "department order by uv desc limit 10"; + replaceSql = SqlParserReplaceHelper.replaceAlias(sql); + replaceSql = SqlParserAddHelper.addAggregateToField(replaceSql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, count(DISTINCT uv) AS uv " + + "FROM t_1 WHERE sys_imp_date = '2023-09-11' GROUP BY department " + + "ORDER BY count(DISTINCT uv) DESC LIMIT 10", + replaceSql); + } + @Test void addGroupBy() { String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java index 13dfa3553..3db4f361e 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java @@ -125,7 +125,8 @@ public class QueryReqConverter { // if there is count() in S2QL,set MetricTable's aggOption to "NATIVE" String sql = databaseReq.getSql(); if (!SqlParserSelectHelper.hasGroupBy(sql) - || SqlParserSelectFunctionHelper.hasFunction(sql, "count")) { + || SqlParserSelectFunctionHelper.hasFunction(sql, "count") + || SqlParserSelectFunctionHelper.hasFunction(sql, "count_distinct")) { return AggOption.NATIVE; } return AggOption.DEFAULT;