(improvement)(chat) support remove where condition and fix simplifySql space error and addAggregateToMetric optimize (#170)

This commit is contained in:
lexluo09
2023-10-07 21:51:37 +08:00
committed by GitHub
parent eccd791a39
commit 4ccee8b107
9 changed files with 128 additions and 24 deletions

View File

@@ -30,7 +30,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
if (!CollectionUtils.isEmpty(selectFields)
&& !CollectionUtils.isEmpty(metrics)
&& !selectFields.stream().anyMatch(s -> metrics.contains(s))) {
&& selectFields.stream().anyMatch(s -> metrics.contains(s))) {
//add aggregate to all metric
addAggregateToMetric(semanticCorrectInfo);
}

View File

@@ -23,6 +23,9 @@ public class NatureHelper {
public static SchemaElementType convertToElementType(String nature) {
DictWordType dictWordType = DictWordType.getNatureType(nature);
if (Objects.isNull(dictWordType)) {
return null;
}
SchemaElementType result = null;
switch (dictWordType) {
case METRIC:

View File

@@ -13,11 +13,6 @@ import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
@Slf4j
public class DateFunctionHelper {
public static final String DATE_FUNCTION = "datediff";
public static final double HALF_YEAR = 0.5d;
public static final int SIX_MONTH = 6;
public static final String EQUAL = "=";
public static String getStartDateStr(ComparisonOperator minorThanEquals, List<Expression> expressions) {
String unitValue = getUnit(expressions);
String dateValue = getEndDateValue(expressions);
@@ -27,9 +22,9 @@ public class DateFunctionHelper {
if (rightExpression instanceof DoubleValue) {
DoubleValue value = (DoubleValue) rightExpression;
double doubleValue = value.getValue();
if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) {
if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == JsqlConstants.HALF_YEAR) {
datePeriodEnum = DatePeriodEnum.MONTH;
dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum);
dateStr = DateUtils.getBeforeDate(dateValue, JsqlConstants.SIX_MONTH, datePeriodEnum);
}
} else if (rightExpression instanceof LongValue) {
LongValue value = (LongValue) rightExpression;
@@ -41,7 +36,7 @@ public class DateFunctionHelper {
public static String getEndDateOperator(ComparisonOperator comparisonOperator) {
String operator = comparisonOperator.getStringExpression();
if (EQUAL.equalsIgnoreCase(operator)) {
if (JsqlConstants.EQUAL.equalsIgnoreCase(operator)) {
operator = "<=";
}
return operator;

View File

@@ -31,7 +31,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(MinorThan expr) {
List<Expression> expressions = parserFilter(expr, " 1 < 2 ");
List<Expression> expressions = parserFilter(expr, JsqlConstants.MINOR_THAN_CONSTANT);
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -39,7 +39,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(EqualsTo expr) {
List<Expression> expressions = parserFilter(expr, " 1 = 1 ");
List<Expression> expressions = parserFilter(expr, JsqlConstants.EQUAL_CONSTANT);
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -47,7 +47,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(MinorThanEquals expr) {
List<Expression> expressions = parserFilter(expr, " 1 <= 1 ");
List<Expression> expressions = parserFilter(expr, JsqlConstants.MINOR_THAN_EQUALS_CONSTANT);
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -56,7 +56,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(GreaterThan expr) {
List<Expression> expressions = parserFilter(expr, " 2 > 1 ");
List<Expression> expressions = parserFilter(expr, JsqlConstants.GREATER_THAN_CONSTANT);
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -64,7 +64,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
@Override
public void visit(GreaterThanEquals expr) {
List<Expression> expressions = parserFilter(expr, " 1 >= 1 ");
List<Expression> expressions = parserFilter(expr, JsqlConstants.GREATER_THAN_EQUALS_CONSTANT);
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
@@ -83,7 +83,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
return result;
}
Function leftExpressionFunction = (Function) leftExpression;
if (leftExpressionFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) {
if (leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
return result;
}

View File

@@ -77,7 +77,7 @@ public class FunctionReplaceVisitor extends ExpressionVisitorAdapter {
return result;
}
Function leftExpressionFunction = (Function) leftExpression;
if (!leftExpressionFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) {
if (!leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
return result;
}
List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
@@ -98,9 +98,9 @@ public class FunctionReplaceVisitor extends ExpressionVisitorAdapter {
String startDataCondExpr =
columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue);
if (DateFunctionHelper.EQUAL.equalsIgnoreCase(endDateOperator)) {
if (JsqlConstants.EQUAL.equalsIgnoreCase(endDateOperator)) {
result.add(CCJSqlParserUtil.parseCondExpression(condExpr));
expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 ");
expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(JsqlConstants.EQUAL_CONSTANT);
}
comparisonOperator.setLeftExpression(null);
comparisonOperator.setRightExpression(null);

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.common.util.jsqlparser;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class JsqlConstants {
public static final String DATE_FUNCTION = "datediff";
public static final double HALF_YEAR = 0.5d;
public static final int SIX_MONTH = 6;
public static final String EQUAL = "=";
public static final String MINOR_THAN_CONSTANT = " 1 < 2 ";
public static final String MINOR_THAN_EQUALS_CONSTANT = " 1 <= 1 ";
public static final String GREATER_THAN_CONSTANT = " 2 > 1 ";
public static final String GREATER_THAN_EQUALS_CONSTANT = " 1 >= 1 ";
public static final String EQUAL_CONSTANT = " 1 = 1 ";
}

View File

@@ -5,6 +5,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
@@ -15,6 +16,7 @@ import net.sf.jsqlparser.expression.operators.conditional.XorExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.GroupByElement;
@@ -478,5 +480,75 @@ public class SqlParserUpdateHelper {
}
return selectStatement.toString();
}
public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
selectBody.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
removeWhereCondition(plainSelect.getWhere(), removeFieldNames);
}
});
return selectStatement.toString();
}
private static void removeWhereCondition(Expression whereExpression, Set<String> removeFieldNames) {
if (whereExpression == null) {
return;
}
removeWhereExpression(whereExpression, removeFieldNames);
}
private static void removeWhereExpression(Expression whereExpression, Set<String> removeFieldNames) {
if (isLogicExpression(whereExpression)) {
AndExpression andExpression = (AndExpression) whereExpression;
Expression leftExpression = andExpression.getLeftExpression();
Expression rightExpression = andExpression.getRightExpression();
if (isLogicExpression(leftExpression)) {
removeWhereExpression(leftExpression, removeFieldNames);
} else {
removeExpressionWithConstant(leftExpression, removeFieldNames);
}
if (isLogicExpression(rightExpression)) {
removeWhereExpression(rightExpression, removeFieldNames);
} else {
removeExpressionWithConstant(rightExpression, removeFieldNames);
}
removeExpressionWithConstant(rightExpression, removeFieldNames);
} else {
removeExpressionWithConstant(whereExpression, removeFieldNames);
}
}
private static void removeExpressionWithConstant(Expression expression, Set<String> removeFieldNames) {
if (!(expression instanceof EqualsTo)) {
return;
}
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
String columnName = "";
if (comparisonOperator.getRightExpression() instanceof Column) {
columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName();
}
if (comparisonOperator.getLeftExpression() instanceof Column) {
columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
}
if (!removeFieldNames.contains(columnName)) {
return;
}
try {
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.EQUAL_CONSTANT);
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
comparisonOperator.setASTNode(constantExpression.getASTNode());
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}
}
}

View File

@@ -453,6 +453,23 @@ class SqlParserUpdateHelperTest {
replaceSql);
}
@Test
void removeWhereCondition() {
String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
Set<String> removeFieldNames = new HashSet<>();
removeFieldNames.add("歌曲名");
String replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
+ "ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
}
private Map<String, String> initParams() {
Map<String, String> fieldToBizName = new HashMap<>();

View File

@@ -2,13 +2,14 @@ package com.tencent.supersonic.semantic.query.parser.calcite.sql.node;
import com.tencent.supersonic.semantic.query.parser.calcite.Configuration;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.Optimization;
import com.tencent.supersonic.semantic.query.parser.calcite.schema.SemanticSqlDialect;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.Optimization;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import org.apache.calcite.sql.SqlAsOperator;
import org.apache.calcite.sql.SqlBasicCall;
@@ -17,7 +18,6 @@ import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlWriterConfig;
import org.apache.calcite.sql.advise.SqlSimpleParser;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.pretty.SqlPrettyWriter;
@@ -41,13 +41,12 @@ public abstract class SemanticNode {
}
public static String getSql(SqlNode sqlNode) {
SqlSimpleParser sqlSimpleParser = new SqlSimpleParser("", Configuration.getParserConfig());
SqlWriterConfig config = SqlPrettyWriter.config().withDialect(SemanticSqlDialect.DEFAULT)
.withKeywordsLowerCase(true).withClauseEndsLine(true).withAlwaysUseParentheses(false)
.withSelectListItemsOnSeparateLines(false).withUpdateSetListNewline(false).withIndentation(0);
return sqlSimpleParser.simplifySql(sqlNode.toSqlString((c) -> {
return config;
}).getSql());
UnaryOperator<SqlWriterConfig> sqlWriterConfigUnaryOperator = (c) -> config;
return sqlNode.toSqlString(sqlWriterConfigUnaryOperator).getSql();
}
public static boolean isNumeric(String expr) {