diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index 23bc459be..ded4ea1d0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -1,14 +1,6 @@ package com.tencent.supersonic.common.jsqlparser; import com.tencent.supersonic.common.util.StringUtil; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Alias; @@ -50,6 +42,15 @@ import net.sf.jsqlparser.statement.select.WithItem; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + /** * Sql Parser Select Helper */ @@ -97,6 +98,22 @@ public class SqlSelectHelper { }); } + public static List gePureSelectFields(String sql) { + List plainSelectList = getPlainSelect(sql); + Set result = new HashSet<>(); + plainSelectList.stream().forEach(plainSelect -> { + List> selectItems = plainSelect.getSelectItems(); + for (SelectItem selectItem : selectItems) { + if (!(selectItem.getExpression() instanceof Column)) { + continue; + } + Column column = (Column) selectItem.getExpression(); + result.add(column.getColumnName()); + } + }); + return new ArrayList<>(result); + } + public static List getSelectFields(String sql) { List plainSelectList = getPlainSelect(sql); if (CollectionUtils.isEmpty(plainSelectList)) { diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java index d9d1fa411..fce07dca1 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java @@ -282,4 +282,23 @@ class SqlSelectHelperTest { Assert.assertEquals(tableName, "超音数"); } + @Test + void testGetPureSelectFields() { + + String sql = "select TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` " + + "where 数据日期 = '2023-08-08' and 用户 = 'alice'"; + List selectFields = SqlSelectHelper.gePureSelectFields(sql); + Assert.assertEquals(selectFields.size(), 0); + + sql = "select 发布日期,数据日期 from `超音数` where " + + "数据日期 = '2023-08-08' and 用户 = 'alice'"; + selectFields = SqlSelectHelper.gePureSelectFields(sql); + Assert.assertEquals(selectFields.size(), 2); + + sql = "select 发布日期,数据日期,TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` where " + + "数据日期 = '2023-08-08' and 用户 = 'alice'"; + selectFields = SqlSelectHelper.gePureSelectFields(sql); + Assert.assertEquals(selectFields.size(), 2); + } + } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java index f58e894d1..c971df07a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java @@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.core.env.Environment; import org.springframework.util.CollectionUtils; + import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -72,7 +73,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); //add alias field name Set dimensions = getDimensions(dataSetId, semanticSchema); - List selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); + List selectFields = SqlSelectHelper.gePureSelectFields(correctS2SQL); List aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL); Set groupByFields = selectFields.stream() .filter(field -> dimensions.contains(field))