[improvement][headless] fixed error sql order item is the same as the agg alias (#1016)

This commit is contained in:
jipeli
2024-05-20 19:37:56 +08:00
committed by GitHub
parent d513b6d2cc
commit 40dc5e2607
3 changed files with 64 additions and 4 deletions

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.util.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.StringUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -13,6 +14,7 @@ import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
@@ -515,6 +517,50 @@ public class SqlReplaceHelper {
}
}
private static Select replaceAggAliasOrderItem(Select selectStatement) {
if (selectStatement instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectStatement;
if (Objects.nonNull(plainSelect.getOrderByElements())) {
Map<String, String> selectNames = new HashMap<>();
for (int i = 0; i < plainSelect.getSelectItems().size(); i++) {
SelectItem<?> f = plainSelect.getSelectItem(i);
if (Objects.nonNull(f.getAlias()) && f.getExpression() instanceof Function) {
Function function = (Function) f.getExpression();
String alias = f.getAlias().getName();
if (function.getParameters().size() == 1 && function.getParameters().get(0) instanceof Column) {
Column column = (Column) function.getParameters().get(0);
if (column.getColumnName().equalsIgnoreCase(alias)) {
selectNames.put(alias, String.valueOf(i + 1));
}
}
}
}
plainSelect.getOrderByElements().stream().forEach(o -> {
if (o.getExpression() instanceof Function) {
Function function = (Function) o.getExpression();
if (function.getParameters().size() == 1 && function.getParameters().get(0) instanceof Column) {
Column column = (Column) function.getParameters().get(0);
if (selectNames.containsKey(column.getColumnName())) {
o.setExpression(new LongValue(selectNames.get(column.getColumnName())));
}
}
}
});
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
parenthesedSelect.setSelect(replaceAggAliasOrderItem(parenthesedSelect.getSelect()));
}
return selectStatement;
}
return selectStatement;
}
public static String replaceAggAliasOrderItem(String sql) {
Select selectStatement = replaceAggAliasOrderItem(SqlSelectHelper.getSelect(sql));
return selectStatement.toString();
}
public static String replaceExpression(String expr, Map<String, String> replace) {
Expression expression = QueryExpressionReplaceVisitor.getExpression(expr);
if (Objects.nonNull(expression)) {

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.common.util.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
/**
* SqlParserReplaceHelperTest
@@ -479,6 +478,17 @@ class SqlReplaceHelperTest {
}
@Test
void testReplaceAggAliasOrderItem() {
String sql = "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 "
+ "GROUP BY 部门 ORDER BY SUM(访问次数) DESC LIMIT 10) AS top10";
String replaceSql = SqlReplaceHelper.replaceAggAliasOrderItem(sql);
Assert.assertEquals(
"SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 "
+ "GROUP BY 部门 ORDER BY 2 DESC LIMIT 10) AS top10",
replaceSql);
}
private Map<String, String> initParams() {
Map<String, String> fieldToBizName = new HashMap<>();
fieldToBizName.put("部门", "department");