(improvement)(semantic) support multi sub-query in select (#288)

This commit is contained in:
lexluo09
2023-10-25 17:36:46 +08:00
committed by GitHub
parent eb9db28352
commit d533496b2a
3 changed files with 39 additions and 19 deletions

View File

@@ -123,34 +123,34 @@ public class SqlParserSelectHelper {
public static List<PlainSelect> getPlainSelects(PlainSelect plainSelect) { public static List<PlainSelect> getPlainSelects(PlainSelect plainSelect) {
List<PlainSelect> plainSelects = new ArrayList<>(); List<PlainSelect> plainSelects = new ArrayList<>();
plainSelects.add(plainSelect); plainSelects.add(plainSelect);
ExpressionVisitorAdapter expressionVisitor = new ExpressionVisitorAdapter() {
@Override
public void visit(SubSelect subSelect) {
SelectBody subSelectBody = subSelect.getSelectBody();
if (subSelectBody instanceof PlainSelect) {
plainSelects.add((PlainSelect) subSelectBody);
}
}
};
plainSelect.accept(new SelectVisitorAdapter() { plainSelect.accept(new SelectVisitorAdapter() {
@Override @Override
public void visit(PlainSelect plainSelect) { public void visit(PlainSelect plainSelect) {
Expression whereExpression = plainSelect.getWhere(); Expression whereExpression = plainSelect.getWhere();
if (whereExpression != null) { if (whereExpression != null) {
whereExpression.accept(new ExpressionVisitorAdapter() { whereExpression.accept(expressionVisitor);
@Override
public void visit(SubSelect subSelect) {
SelectBody subSelectBody = subSelect.getSelectBody();
if (subSelectBody instanceof PlainSelect) {
plainSelects.add((PlainSelect) subSelectBody);
}
}
});
} }
Expression having = plainSelect.getHaving(); Expression having = plainSelect.getHaving();
if (Objects.nonNull(having)) { if (Objects.nonNull(having)) {
having.accept(new ExpressionVisitorAdapter() { having.accept(expressionVisitor);
@Override }
public void visit(SubSelect subSelect) { List<SelectItem> selectItems = plainSelect.getSelectItems();
SelectBody subSelectBody = subSelect.getSelectBody(); if (!CollectionUtils.isEmpty(selectItems)) {
if (subSelectBody instanceof PlainSelect) { for (SelectItem selectItem : selectItems) {
plainSelects.add((PlainSelect) subSelectBody); selectItem.accept(expressionVisitor);
} }
}
});
} }
} }
}); });
return plainSelects; return plainSelects;

View File

@@ -295,6 +295,17 @@ class SqlParserReplaceHelperTest {
"SELECT song_name, sum(user_id) FROM CSpider WHERE (1 < 2) AND " "SELECT song_name, sum(user_id) FROM CSpider WHERE (1 < 2) AND "
+ "sys_imp_date = '2023-10-15' GROUP BY song_name HAVING " + "sys_imp_date = '2023-10-15' GROUP BY song_name HAVING "
+ "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql); + "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql);
replaceSql = "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')"
+ " FROM CSpider WHERE 数据日期 = '2023-10-15' "
+ "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')";
replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName);
replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT sum(user_id) / (SELECT sum(user_id) FROM CSpider WHERE sys_imp_date = '2023-10-15') "
+ "FROM CSpider WHERE sys_imp_date = '2023-10-15' GROUP BY song_name HAVING "
+ "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql);
} }

View File

@@ -108,6 +108,11 @@ class SqlParserSelectHelperTest {
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
System.out.println(filterExpression); System.out.println(filterExpression);
filterExpression = SqlParserSelectHelper.getFilterExpression("SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 "
+ "WHERE MONTH(数据日期) = 9) FROM 营销月模型 WHERE 国家中文名 = '肯尼亚' AND MONTH(数据日期) = 9");
System.out.println(filterExpression);
} }
@@ -157,6 +162,10 @@ class SqlParserSelectHelperTest {
Assert.assertEquals(allFields.size(), 3); Assert.assertEquals(allFields.size(), 3);
allFields = SqlParserSelectHelper.getAllFields("SELECT sum(销量) / (SELECT sum(销量) FROM 营销 "
+ "WHERE MONTH(数据日期) = 9) FROM 营销 WHERE 国家中文名 = '中国' AND MONTH(数据日期) = 9");
Assert.assertEquals(allFields.size(), 3);
} }