mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(chat) When making corrections, the 'group by' field must not be included in the function. (#1532)
This commit is contained in:
@@ -1,14 +1,6 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Alias;
|
import net.sf.jsqlparser.expression.Alias;
|
||||||
@@ -50,6 +42,15 @@ import net.sf.jsqlparser.statement.select.WithItem;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Parser Select Helper
|
* Sql Parser Select Helper
|
||||||
*/
|
*/
|
||||||
@@ -97,6 +98,22 @@ public class SqlSelectHelper {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static List<String> gePureSelectFields(String sql) {
|
||||||
|
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
||||||
|
Set<String> result = new HashSet<>();
|
||||||
|
plainSelectList.stream().forEach(plainSelect -> {
|
||||||
|
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
|
||||||
|
for (SelectItem selectItem : selectItems) {
|
||||||
|
if (!(selectItem.getExpression() instanceof Column)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Column column = (Column) selectItem.getExpression();
|
||||||
|
result.add(column.getColumnName());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return new ArrayList<>(result);
|
||||||
|
}
|
||||||
|
|
||||||
public static List<String> getSelectFields(String sql) {
|
public static List<String> getSelectFields(String sql) {
|
||||||
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
||||||
if (CollectionUtils.isEmpty(plainSelectList)) {
|
if (CollectionUtils.isEmpty(plainSelectList)) {
|
||||||
|
|||||||
@@ -282,4 +282,23 @@ class SqlSelectHelperTest {
|
|||||||
Assert.assertEquals(tableName, "超音数");
|
Assert.assertEquals(tableName, "超音数");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testGetPureSelectFields() {
|
||||||
|
|
||||||
|
String sql = "select TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` "
|
||||||
|
+ "where 数据日期 = '2023-08-08' and 用户 = 'alice'";
|
||||||
|
List<String> selectFields = SqlSelectHelper.gePureSelectFields(sql);
|
||||||
|
Assert.assertEquals(selectFields.size(), 0);
|
||||||
|
|
||||||
|
sql = "select 发布日期,数据日期 from `超音数` where "
|
||||||
|
+ "数据日期 = '2023-08-08' and 用户 = 'alice'";
|
||||||
|
selectFields = SqlSelectHelper.gePureSelectFields(sql);
|
||||||
|
Assert.assertEquals(selectFields.size(), 2);
|
||||||
|
|
||||||
|
sql = "select 发布日期,数据日期,TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` where "
|
||||||
|
+ "数据日期 = '2023-08-08' and 用户 = 'alice'";
|
||||||
|
selectFields = SqlSelectHelper.gePureSelectFields(sql);
|
||||||
|
Assert.assertEquals(selectFields.size(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.core.env.Environment;
|
import org.springframework.core.env.Environment;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
@@ -72,7 +73,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
//add alias field name
|
//add alias field name
|
||||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
List<String> selectFields = SqlSelectHelper.gePureSelectFields(correctS2SQL);
|
||||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
Set<String> groupByFields = selectFields.stream()
|
Set<String> groupByFields = selectFields.stream()
|
||||||
.filter(field -> dimensions.contains(field))
|
.filter(field -> dimensions.contains(field))
|
||||||
|
|||||||
Reference in New Issue
Block a user