From a87304b22b807b2d30410cf9d411fa49065df64f Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:54:51 +0800 Subject: [PATCH] (improvement)(chat) dsl corrector support add agg metric in having (#95) --- .../corrector/SelectFieldAppendCorrector.java | 8 +++++ .../SelectFieldAppendCorrectorTest.java | 20 +++++++++++++ .../jsqlparser/SqlParserSelectHelper.java | 19 ++++++++++++ .../jsqlparser/SqlParserUpdateHelper.java | 29 +++++++++++++++++++ .../jsqlparser/SqlParserSelectHelperTest.java | 11 +++++++ .../jsqlparser/SqlParserUpdateHelperTest.java | 24 +++++++++++++++ pom.xml | 2 +- 7 files changed, 112 insertions(+), 1 deletion(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java index e906de510..5476370fb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java @@ -6,8 +6,10 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import java.util.ArrayList; import java.util.HashSet; +import java.util.Objects; import java.util.Set; import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Expression; import org.springframework.util.CollectionUtils; @Slf4j @@ -17,6 +19,12 @@ public class SelectFieldAppendCorrector extends BaseSemanticCorrector { public void correct(SemanticCorrectInfo semanticCorrectInfo) { String preSql = semanticCorrectInfo.getSql(); if (SqlParserSelectHelper.hasAggregateFunction(preSql)) { + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql); + if (Objects.nonNull(havingExpression)) { + String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(preSql, havingExpression); + semanticCorrectInfo.setPreSql(preSql); + semanticCorrectInfo.setSql(replaceSql); + } return; } Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql)); diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java index 13d648a34..39db3935d 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java @@ -22,5 +22,25 @@ class SelectFieldAppendCorrectorTest { + "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'" + " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql()); + semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30" + + " group by 用户名 having sum(访问次数) > 2000"); + + corrector.correct(semanticCorrectInfo); + + Assert.assertEquals( + "SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " + + "datediff('day', 数据日期, '2023-09-14') <= 30 " + + "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql()); + + semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " + + "datediff('day', 数据日期, '2023-09-14') <= 30 " + + "GROUP BY 用户名 HAVING sum(访问次数) > 2000"); + + corrector.correct(semanticCorrectInfo); + + Assert.assertEquals( + "SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " + + "datediff('day', 数据日期, '2023-09-14') <= 30 " + + "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql()); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 7ac7d2ea2..59d40f166 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -8,6 +8,8 @@ import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; @@ -130,6 +132,23 @@ public class SqlParserSelectHelper { } + public static Expression getHavingExpression(String sql) { + PlainSelect plainSelect = getPlainSelect(sql); + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + if (!(having instanceof ComparisonOperator)) { + return null; + } + ComparisonOperator comparisonOperator = (ComparisonOperator) having; + if (comparisonOperator.getLeftExpression() instanceof Function) { + return comparisonOperator.getLeftExpression(); + } else if (comparisonOperator.getRightExpression() instanceof Function) { + return comparisonOperator.getRightExpression(); + } + } + return null; + } + public static List getOrderByFields(String sql) { PlainSelect plainSelect = getPlainSelect(sql); if (Objects.isNull(plainSelect)) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index 16fc5ac8c..b841613b1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -6,6 +6,7 @@ import java.util.Objects; import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; @@ -17,6 +18,7 @@ import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.util.SelectUtils; import org.apache.commons.lang3.StringUtils; @@ -172,6 +174,33 @@ public class SqlParserUpdateHelper { return selectStatement.toString(); } + public static String addFunctionToSelect(String sql, Expression expression) { + PlainSelect plainSelect = SqlParserSelectHelper.getPlainSelect(sql); + if (Objects.isNull(plainSelect)) { + return sql; + } + List selectItems = plainSelect.getSelectItems(); + if (CollectionUtils.isEmpty(selectItems)) { + return sql; + } + boolean existFunction = false; + for (SelectItem selectItem : selectItems) { + SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; + if (expressionItem.getExpression() instanceof Function) { + Function expressionFunction = (Function) expressionItem.getExpression(); + if (expression.toString().equalsIgnoreCase(expressionFunction.toString())) { + existFunction = true; + break; + } + } + } + if (!existFunction) { + SelectExpressionItem sumExpressionItem = new SelectExpressionItem(expression); + selectItems.add(sumExpressionItem); + } + return plainSelect.toString(); + } + public static String replaceTable(String sql, String tableName) { if (StringUtils.isEmpty(tableName)) { return sql; diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index fcb702f9b..65d870252 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -250,4 +250,15 @@ class SqlParserSelectHelperTest { } + @Test + void getHavingExpression() { + + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + Expression leftExpression = SqlParserSelectHelper.getHavingExpression(sql); + + Assert.assertEquals(leftExpression.toString(), "sum(pv)"); + + } + } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index 0482eceed..77c0ab4c2 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -242,6 +242,30 @@ class SqlParserUpdateHelperTest { } + @Test + void addFunctionToSelect() { + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + } + private Map initParams() { Map fieldToBizName = new HashMap<>(); fieldToBizName.put("部门", "department"); diff --git a/pom.xml b/pom.xml index 824e671d8..c9bdfff2d 100644 --- a/pom.xml +++ b/pom.xml @@ -64,7 +64,7 @@ 3.2.4 4.5.1 4.5 - 0.7.4 + 0.7.5-SNAPSHOT