mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(improvement)(Headless) distinct select fields in S2CorrectSQL (#912)
This commit is contained in:
@@ -5,6 +5,7 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
@@ -12,6 +13,7 @@ import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.Parenthesis;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
@@ -225,7 +227,7 @@ public class SqlAddHelper {
|
||||
}
|
||||
|
||||
private static void addAggregateToSelectItems(List<SelectItem> selectItems,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
for (SelectItem selectItem : selectItems) {
|
||||
if (selectItem instanceof SelectExpressionItem) {
|
||||
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
|
||||
@@ -240,7 +242,7 @@ public class SqlAddHelper {
|
||||
}
|
||||
|
||||
private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
if (orderByElements == null) {
|
||||
return;
|
||||
}
|
||||
@@ -255,7 +257,7 @@ public class SqlAddHelper {
|
||||
}
|
||||
|
||||
private static void addAggregateToGroupByItems(GroupByElement groupByElement,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
if (groupByElement == null) {
|
||||
return;
|
||||
}
|
||||
@@ -276,13 +278,22 @@ public class SqlAddHelper {
|
||||
}
|
||||
|
||||
private static void modifyWhereExpression(Expression whereExpression,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
if (SqlSelectHelper.isLogicExpression(whereExpression)) {
|
||||
AndExpression andExpression = (AndExpression) whereExpression;
|
||||
Expression leftExpression = andExpression.getLeftExpression();
|
||||
Expression rightExpression = andExpression.getRightExpression();
|
||||
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
||||
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
||||
if (whereExpression instanceof AndExpression) {
|
||||
AndExpression andExpression = (AndExpression) whereExpression;
|
||||
Expression leftExpression = andExpression.getLeftExpression();
|
||||
Expression rightExpression = andExpression.getRightExpression();
|
||||
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
||||
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
||||
}
|
||||
if (whereExpression instanceof OrExpression) {
|
||||
OrExpression orExpression = (OrExpression) whereExpression;
|
||||
Expression leftExpression = orExpression.getLeftExpression();
|
||||
Expression rightExpression = orExpression.getRightExpression();
|
||||
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
||||
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
||||
}
|
||||
} else if (whereExpression instanceof Parenthesis) {
|
||||
modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate);
|
||||
} else {
|
||||
|
||||
@@ -29,6 +29,7 @@ import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.Objects;
|
||||
@@ -60,6 +61,32 @@ public class SqlRemoveHelper {
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String removeSameFieldFromSelect(String sql) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
if (selectStatement == null) {
|
||||
return sql;
|
||||
}
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
List<SelectItem> selectItems = ((PlainSelect) selectBody).getSelectItems();
|
||||
Set<String> fields = new HashSet<>();
|
||||
selectItems.removeIf(selectItem -> {
|
||||
if (selectItem instanceof SelectExpressionItem) {
|
||||
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
|
||||
String field = selectExpressionItem.getExpression().toString();
|
||||
if (fields.contains(field)) {
|
||||
return true;
|
||||
}
|
||||
fields.add(field);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
((PlainSelect) selectBody).setSelectItems(selectItems);
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -10,6 +11,21 @@ import org.junit.jupiter.api.Test;
|
||||
*/
|
||||
class SqlRemoveHelperTest {
|
||||
|
||||
@Test
|
||||
void testRemoveSameFieldFromSelect() {
|
||||
String sql = "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)"
|
||||
+ " from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
|
||||
+ "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1";
|
||||
sql = SqlRemoveHelper.removeSameFieldFromSelect(sql);
|
||||
System.out.println(sql);
|
||||
sql = "SELECT 结算播放量 FROM 艺人 WHERE (歌手名 IN ('林俊杰', '陈奕迅')) AND (数据日期 >= '2024-04-04' AND 数据日期 <= '2024-04-04')";
|
||||
List<FieldExpression> fieldExpressionList = SqlSelectHelper.getWhereExpressions(sql);
|
||||
fieldExpressionList.stream().forEach(fieldExpression -> {
|
||||
System.out.println(fieldExpression.toString());
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRemoveWhereHavingCondition() {
|
||||
String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
|
||||
|
||||
Reference in New Issue
Block a user