From 4ccee8b1077afda8b3869d2efae75485c427fc10 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sat, 7 Oct 2023 21:51:37 +0800 Subject: [PATCH] (improvement)(chat) support remove where condition and fix simplifySql space error and addAggregateToMetric optimize (#170) --- .../chat/corrector/GroupByCorrector.java | 2 +- .../knowledge/utils/NatureHelper.java | 3 + .../util/jsqlparser/DateFunctionHelper.java | 11 +-- .../jsqlparser/FiledFilterReplaceVisitor.java | 12 ++-- .../jsqlparser/FunctionReplaceVisitor.java | 6 +- .../common/util/jsqlparser/JsqlConstants.java | 18 +++++ .../jsqlparser/SqlParserUpdateHelper.java | 72 +++++++++++++++++++ .../jsqlparser/SqlParserUpdateHelperTest.java | 17 +++++ .../parser/calcite/sql/node/SemanticNode.java | 11 ++- 9 files changed, 128 insertions(+), 24 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index a09075ce4..8e715f4e5 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -30,7 +30,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { if (!CollectionUtils.isEmpty(selectFields) && !CollectionUtils.isEmpty(metrics) - && !selectFields.stream().anyMatch(s -> metrics.contains(s))) { + && selectFields.stream().anyMatch(s -> metrics.contains(s))) { //add aggregate to all metric addAggregateToMetric(semanticCorrectInfo); } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/NatureHelper.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/NatureHelper.java index 0724108d2..c3e945e15 100644 --- 
a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/NatureHelper.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/NatureHelper.java @@ -23,6 +23,9 @@ public class NatureHelper { public static SchemaElementType convertToElementType(String nature) { DictWordType dictWordType = DictWordType.getNatureType(nature); + if (Objects.isNull(dictWordType)) { + return null; + } SchemaElementType result = null; switch (dictWordType) { case METRIC: diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java index 534374e2e..15b1a7666 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java @@ -13,11 +13,6 @@ import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; @Slf4j public class DateFunctionHelper { - public static final String DATE_FUNCTION = "datediff"; - public static final double HALF_YEAR = 0.5d; - public static final int SIX_MONTH = 6; - public static final String EQUAL = "="; - public static String getStartDateStr(ComparisonOperator minorThanEquals, List expressions) { String unitValue = getUnit(expressions); String dateValue = getEndDateValue(expressions); @@ -27,9 +22,9 @@ public class DateFunctionHelper { if (rightExpression instanceof DoubleValue) { DoubleValue value = (DoubleValue) rightExpression; double doubleValue = value.getValue(); - if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) { + if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == JsqlConstants.HALF_YEAR) { datePeriodEnum = DatePeriodEnum.MONTH; - dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum); + dateStr = DateUtils.getBeforeDate(dateValue, JsqlConstants.SIX_MONTH, datePeriodEnum); } } else if 
(rightExpression instanceof LongValue) { LongValue value = (LongValue) rightExpression; @@ -41,7 +36,7 @@ public class DateFunctionHelper { public static String getEndDateOperator(ComparisonOperator comparisonOperator) { String operator = comparisonOperator.getStringExpression(); - if (EQUAL.equalsIgnoreCase(operator)) { + if (JsqlConstants.EQUAL.equalsIgnoreCase(operator)) { operator = "<="; } return operator; diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java index b3f88d756..a3c968177 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java @@ -31,7 +31,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(MinorThan expr) { - List expressions = parserFilter(expr, " 1 < 2 "); + List expressions = parserFilter(expr, JsqlConstants.MINOR_THAN_CONSTANT); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -39,7 +39,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(EqualsTo expr) { - List expressions = parserFilter(expr, " 1 = 1 "); + List expressions = parserFilter(expr, JsqlConstants.EQUAL_CONSTANT); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -47,7 +47,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(MinorThanEquals expr) { - List expressions = parserFilter(expr, " 1 <= 1 "); + List expressions = parserFilter(expr, JsqlConstants.MINOR_THAN_EQUALS_CONSTANT); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -56,7 +56,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override 
public void visit(GreaterThan expr) { - List expressions = parserFilter(expr, " 2 > 1 "); + List expressions = parserFilter(expr, JsqlConstants.GREATER_THAN_CONSTANT); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -64,7 +64,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(GreaterThanEquals expr) { - List expressions = parserFilter(expr, " 1 >= 1 "); + List expressions = parserFilter(expr, JsqlConstants.GREATER_THAN_EQUALS_CONSTANT); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -83,7 +83,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { return result; } Function leftExpressionFunction = (Function) leftExpression; - if (leftExpressionFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) { + if (leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) { return result; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java index 8fcd0afe4..dee99a1df 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java @@ -77,7 +77,7 @@ public class FunctionReplaceVisitor extends ExpressionVisitorAdapter { return result; } Function leftExpressionFunction = (Function) leftExpression; - if (!leftExpressionFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) { + if (!leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) { return result; } List leftExpressions = leftExpressionFunction.getParameters().getExpressions(); @@ -98,9 +98,9 @@ public class FunctionReplaceVisitor extends ExpressionVisitorAdapter { String startDataCondExpr = columnName + StringUtil.getSpaceWrap(startDateOperator) 
+ StringUtil.getCommaWrap(startDateValue); - if (DateFunctionHelper.EQUAL.equalsIgnoreCase(endDateOperator)) { + if (JsqlConstants.EQUAL.equalsIgnoreCase(endDateOperator)) { result.add(CCJSqlParserUtil.parseCondExpression(condExpr)); - expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 "); + expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(JsqlConstants.EQUAL_CONSTANT); } comparisonOperator.setLeftExpression(null); comparisonOperator.setRightExpression(null); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java new file mode 100644 index 000000000..2d20478c0 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class JsqlConstants { + + public static final String DATE_FUNCTION = "datediff"; + public static final double HALF_YEAR = 0.5d; + public static final int SIX_MONTH = 6; + public static final String EQUAL = "="; + public static final String MINOR_THAN_CONSTANT = " 1 < 2 "; + public static final String MINOR_THAN_EQUALS_CONSTANT = " 1 <= 1 "; + public static final String GREATER_THAN_CONSTANT = " 2 > 1 "; + public static final String GREATER_THAN_EQUALS_CONSTANT = " 1 >= 1 "; + public static final String EQUAL_CONSTANT = " 1 = 1 "; + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index 7c8095a96..d8d294bfc 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -5,6 
+5,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import lombok.extern.slf4j.Slf4j;
+import net.sf.jsqlparser.JSQLParserException;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Function;
 import net.sf.jsqlparser.expression.LongValue;
@@ -15,6 +16,7 @@ import net.sf.jsqlparser.expression.operators.conditional.XorExpression;
 import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
 import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
 import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.select.GroupByElement;
@@ -478,5 +480,91 @@ public class SqlParserUpdateHelper {
         }
         return selectStatement.toString();
     }
+
+    /**
+     * Neutralises equality filters on the given columns in the WHERE clause of a plain select.
+     * Each matching {@code column = value} comparison is rewritten in place to the constant
+     * {@code 1 = 1}, so the rest of the WHERE expression tree is left untouched.
+     *
+     * @param sql              the select statement to rewrite
+     * @param removeFieldNames names of the columns whose equality filters are neutralised
+     * @return the rewritten SQL, or {@code sql} unchanged when its body is not a {@link PlainSelect}
+     */
+    public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
+        Select selectStatement = SqlParserSelectHelper.getSelect(sql);
+        SelectBody selectBody = selectStatement.getSelectBody();
+
+        if (!(selectBody instanceof PlainSelect)) {
+            return sql;
+        }
+        selectBody.accept(new SelectVisitorAdapter() {
+            @Override
+            public void visit(PlainSelect plainSelect) {
+                removeWhereCondition(plainSelect.getWhere(), removeFieldNames);
+            }
+        });
+        return selectStatement.toString();
+    }
+
+    /** Null-safe entry point: a select without a WHERE clause has nothing to neutralise. */
+    private static void removeWhereCondition(Expression whereExpression, Set<String> removeFieldNames) {
+        if (whereExpression == null) {
+            return;
+        }
+        removeWhereExpression(whereExpression, removeFieldNames);
+    }
+
+    /**
+     * Walks the WHERE tree depth-first and neutralises every matching leaf comparison.
+     *
+     * <p>Logic nodes are handled through {@code BinaryExpression}, the common JSqlParser
+     * superclass of AND, OR and XOR, rather than being cast to {@code AndExpression}: the
+     * cast would throw {@link ClassCastException} for any logic node that is not an AND.
+     * NOTE(review): this assumes {@code isLogicExpression} also accepts OR/XOR nodes; under
+     * OR, substituting {@code 1 = 1} makes the whole disjunction true - confirm that is intended.
+     */
+    private static void removeWhereExpression(Expression whereExpression, Set<String> removeFieldNames) {
+        if (!isLogicExpression(whereExpression)) {
+            removeExpressionWithConstant(whereExpression, removeFieldNames);
+            return;
+        }
+        // Fully qualified so that no additional import hunk is required.
+        net.sf.jsqlparser.expression.BinaryExpression logicExpression =
+                (net.sf.jsqlparser.expression.BinaryExpression) whereExpression;
+        // Recursing covers both nested logic nodes and leaf comparisons, each exactly once.
+        removeWhereExpression(logicExpression.getLeftExpression(), removeFieldNames);
+        removeWhereExpression(logicExpression.getRightExpression(), removeFieldNames);
+    }
+
+    /**
+     * Rewrites {@code expression} to {@code 1 = 1} when it is an equality comparison whose
+     * left or right operand is a column named in {@code removeFieldNames}; any other
+     * expression is left as it is. A parse failure of the constant is logged, not rethrown.
+     */
+    private static void removeExpressionWithConstant(Expression expression, Set<String> removeFieldNames) {
+        if (!(expression instanceof EqualsTo)) {
+            return;
+        }
+        ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
+        String columnName = "";
+        if (comparisonOperator.getRightExpression() instanceof Column) {
+            columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName();
+        }
+        if (comparisonOperator.getLeftExpression() instanceof Column) {
+            columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
+        }
+        if (!removeFieldNames.contains(columnName)) {
+            return;
+        }
+        try {
+            ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
+                    JsqlConstants.EQUAL_CONSTANT);
+            comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
+            comparisonOperator.setRightExpression(constantExpression.getRightExpression());
+            comparisonOperator.setASTNode(constantExpression.getASTNode());
+        } catch (JSQLParserException e) {
+            log.error("JSQLParserException", e);
+        }
+    }
 
 }
diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java
index e93f0b61d..930e1fa11 100644
--- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java
+++ 
b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -453,6 +453,23 @@ class SqlParserUpdateHelperTest { replaceSql); } + @Test + void removeWhereCondition() { + String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Set removeFieldNames = new HashSet<>(); + removeFieldNames.add("歌曲名"); + + String replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "ORDER BY 播放量 DESC LIMIT 11", + replaceSql); + } private Map initParams() { Map fieldToBizName = new HashMap<>(); diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/SemanticNode.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/SemanticNode.java index e510ba8f2..fff931c1d 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/SemanticNode.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/SemanticNode.java @@ -2,13 +2,14 @@ package com.tencent.supersonic.semantic.query.parser.calcite.sql.node; import com.tencent.supersonic.semantic.query.parser.calcite.Configuration; -import com.tencent.supersonic.semantic.query.parser.calcite.sql.Optimization; import com.tencent.supersonic.semantic.query.parser.calcite.schema.SemanticSqlDialect; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.Optimization; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; import 
org.apache.calcite.sql.SqlAsOperator; import org.apache.calcite.sql.SqlBasicCall; @@ -17,7 +18,6 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.SqlWriterConfig; -import org.apache.calcite.sql.advise.SqlSimpleParser; import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.pretty.SqlPrettyWriter; @@ -41,13 +41,12 @@ public abstract class SemanticNode { } public static String getSql(SqlNode sqlNode) { - SqlSimpleParser sqlSimpleParser = new SqlSimpleParser("", Configuration.getParserConfig()); SqlWriterConfig config = SqlPrettyWriter.config().withDialect(SemanticSqlDialect.DEFAULT) .withKeywordsLowerCase(true).withClauseEndsLine(true).withAlwaysUseParentheses(false) .withSelectListItemsOnSeparateLines(false).withUpdateSetListNewline(false).withIndentation(0); - return sqlSimpleParser.simplifySql(sqlNode.toSqlString((c) -> { - return config; - }).getSql()); + + UnaryOperator sqlWriterConfigUnaryOperator = (c) -> config; + return sqlNode.toSqlString(sqlWriterConfigUnaryOperator).getSql(); } public static boolean isNumeric(String expr) {