From 3b1cbd4fd76c2bb736f241d63e17a43ca18c6d32 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Tue, 10 Oct 2023 17:42:19 +0800 Subject: [PATCH] (improvement)(chat) group by corrector remove aggregate fields (#186) --- .../chat/corrector/GroupByCorrector.java | 10 +++++++- .../jsqlparser/SqlParserSelectHelper.java | 24 +++++++++++++++++++ .../jsqlparser/SqlParserSelectHelperTest.java | 10 ++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index 8e715f4e5..87c4328de 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -47,7 +47,15 @@ public class GroupByCorrector extends BaseSemanticCorrector { if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) { return; } - Set groupByFields = selectFields.stream().filter(field -> dimensions.contains(field)) + List aggregateFields = SqlParserSelectHelper.getAggregateFields(sql); + Set groupByFields = selectFields.stream() + .filter(field -> dimensions.contains(field)) + .filter(field -> { + if (!CollectionUtils.isEmpty(aggregateFields) && aggregateFields.contains(field)) { + return false; + } + return true; + }) .collect(Collectors.toSet()); semanticCorrectInfo.setSql(SqlParserUpdateHelper.addGroupBy(sql, groupByFields)); } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 1c49b85bf..cce8972bb 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -19,6 +19,7 @@ import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import org.springframework.util.CollectionUtils; @@ -207,6 +208,29 @@ public class SqlParserSelectHelper { return table.getName(); } + public static List getAggregateFields(String sql) { + PlainSelect plainSelect = getPlainSelect(sql); + if (Objects.isNull(plainSelect)) { + return new ArrayList<>(); + } + Set result = new HashSet<>(); + List selectItems = plainSelect.getSelectItems(); + for (SelectItem selectItem : selectItems) { + if (selectItem instanceof SelectExpressionItem) { + SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; + if (expressionItem.getExpression() instanceof Function) { + Function function = (Function) expressionItem.getExpression(); + if (Objects.nonNull(function.getParameters()) + && !CollectionUtils.isEmpty(function.getParameters().getExpressions())) { + String columnName = function.getParameters().getExpressions().get(0).toString(); + result.add(columnName); + } + } + } + } + return new ArrayList<>(result); + } + public static boolean hasAggregateFunction(String sql) { if (hasFunction(sql)) { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index 5fd55974c..e5c24e4fb 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -261,4 +261,14 @@ class SqlParserSelectHelperTest { } + @Test + void getAggregateFields() { + + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; + List selectFields = SqlParserSelectHelper.getAggregateFields(sql); + Assert.assertEquals(selectFields.contains("访问次数"), true); + + } + }