(improvement)(chat) When correcting GROUP BY, select fields that appear only inside a function call must not be added to the 'group by' fields. (#1532)

This commit is contained in:
lexluo09
2024-08-07 23:08:43 +08:00
committed by GitHub
parent b8aeff9a6a
commit 9dbc8657e2
3 changed files with 46 additions and 9 deletions

View File

@@ -1,14 +1,6 @@
package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.util.StringUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
@@ -50,6 +42,15 @@ import net.sf.jsqlparser.statement.select.WithItem;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Sql Parser Select Helper
*/
@@ -97,6 +98,22 @@ public class SqlSelectHelper {
});
}
/**
 * Collects the names of columns that appear <em>directly</em> as select items.
 *
 * <p>Only select items whose expression is itself a {@link Column} are kept; any other
 * expression (for example a function call that merely references a column) is skipped.
 * Names are de-duplicated through a {@link HashSet}, so the order of the returned list
 * is unspecified.
 *
 * <p>The method name (missing the "t" of "get") is kept as-is so existing callers
 * continue to compile.
 *
 * @param sql the SQL text handed to {@code getPlainSelect}
 * @return a new mutable list of distinct bare column names; empty when nothing qualifies
 */
public static List<String> gePureSelectFields(String sql) {
    List<PlainSelect> plainSelectList = getPlainSelect(sql);
    // Mirror the defensive check getSelectFields performs on the same helper's result.
    if (plainSelectList == null) {
        return new ArrayList<>();
    }
    Set<String> result = new HashSet<>();
    for (PlainSelect plainSelect : plainSelectList) {
        List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
        if (selectItems == null) {
            continue;
        }
        for (SelectItem<?> selectItem : selectItems) {
            Object expression = selectItem.getExpression();
            // Anything that is not a plain column reference (functions, literals, ...) is ignored.
            if (expression instanceof Column) {
                result.add(((Column) expression).getColumnName());
            }
        }
    }
    return new ArrayList<>(result);
}
public static List<String> getSelectFields(String sql) {
List<PlainSelect> plainSelectList = getPlainSelect(sql);
if (CollectionUtils.isEmpty(plainSelectList)) {

View File

@@ -282,4 +282,23 @@ class SqlSelectHelperTest {
Assert.assertEquals(tableName, "超音数");
}
/**
 * Verifies that gePureSelectFields reports only bare select columns and ignores
 * columns that are referenced solely inside a function call.
 */
@Test
void testGetPureSelectFields() {
    // The select list holds only a function call, so no bare column is expected.
    String sql = "select TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` "
            + "where 数据日期 = '2023-08-08' and 用户 = 'alice'";
    List<String> selectFields = SqlSelectHelper.gePureSelectFields(sql);
    Assert.assertEquals(selectFields.size(), 0);

    // Two bare columns: check the names too, since the WHERE clause also mentions
    // exactly two columns and a size-only assertion could not tell them apart.
    sql = "select 发布日期,数据日期 from `超音数` where "
            + "数据日期 = '2023-08-08' and 用户 = 'alice'";
    selectFields = SqlSelectHelper.gePureSelectFields(sql);
    Assert.assertEquals(selectFields.size(), 2);
    Assert.assertTrue(selectFields.contains("发布日期"));
    Assert.assertTrue(selectFields.contains("数据日期"));

    // Bare columns mixed with a function call: only the bare columns are returned.
    sql = "select 发布日期,数据日期,TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` where "
            + "数据日期 = '2023-08-08' and 用户 = 'alice'";
    selectFields = SqlSelectHelper.gePureSelectFields(sql);
    Assert.assertEquals(selectFields.size(), 2);
    Assert.assertTrue(selectFields.contains("发布日期"));
    Assert.assertTrue(selectFields.contains("数据日期"));
}
}

View File

@@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@@ -72,7 +73,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
//add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.gePureSelectFields(correctS2SQL);
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
Set<String> groupByFields = selectFields.stream()
.filter(field -> dimensions.contains(field))