(improvement)(chat) group by corrector remove aggregate fields (#186)

This commit is contained in:
lexluo09
2023-10-10 17:42:19 +08:00
committed by GitHub
parent 500652da36
commit 3b1cbd4fd7
3 changed files with 43 additions and 1 deletions

View File

@@ -19,6 +19,7 @@ import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem;
import org.springframework.util.CollectionUtils;
@@ -207,6 +208,29 @@ public class SqlParserSelectHelper {
return table.getName();
}
public static List<String> getAggregateFields(String sql) {
PlainSelect plainSelect = getPlainSelect(sql);
if (Objects.isNull(plainSelect)) {
return new ArrayList<>();
}
Set<String> result = new HashSet<>();
List<SelectItem> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
if (selectItem instanceof SelectExpressionItem) {
SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem;
if (expressionItem.getExpression() instanceof Function) {
Function function = (Function) expressionItem.getExpression();
if (Objects.nonNull(function.getParameters())
&& !CollectionUtils.isEmpty(function.getParameters().getExpressions())) {
String columnName = function.getParameters().getExpressions().get(0).toString();
result.add(columnName);
}
}
}
}
return new ArrayList<>(result);
}
public static boolean hasAggregateFunction(String sql) {
if (hasFunction(sql)) {

View File

@@ -261,4 +261,14 @@ class SqlParserSelectHelperTest {
}
@Test
void getAggregateFields() {
String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'"
+ " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1";
List<String> selectFields = SqlParserSelectHelper.getAggregateFields(sql);
Assert.assertEquals(selectFields.contains("访问次数"), true);
}
}