(improvement)(chat) dsl corrector support add agg metric in having (#95)

This commit is contained in:
lexluo09
2023-09-15 17:54:51 +08:00
committed by GitHub
parent 3701ade05f
commit a87304b22b
7 changed files with 112 additions and 1 deletions

View File

@@ -6,8 +6,10 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@Slf4j
@@ -17,6 +19,12 @@ public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(preSql, havingExpression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(replaceSql);
}
return;
}
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql));

View File

@@ -22,5 +22,25 @@ class SelectFieldAppendCorrectorTest {
+ "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'"
+ " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30"
+ " group by 用户名 having sum(访问次数) > 2000");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
}
}

View File

@@ -8,6 +8,8 @@ import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
@@ -130,6 +132,23 @@ public class SqlParserSelectHelper {
}
public static Expression getHavingExpression(String sql) {
PlainSelect plainSelect = getPlainSelect(sql);
Expression having = plainSelect.getHaving();
if (Objects.nonNull(having)) {
if (!(having instanceof ComparisonOperator)) {
return null;
}
ComparisonOperator comparisonOperator = (ComparisonOperator) having;
if (comparisonOperator.getLeftExpression() instanceof Function) {
return comparisonOperator.getLeftExpression();
} else if (comparisonOperator.getRightExpression() instanceof Function) {
return comparisonOperator.getRightExpression();
}
}
return null;
}
public static List<String> getOrderByFields(String sql) {
PlainSelect plainSelect = getPlainSelect(sql);
if (Objects.isNull(plainSelect)) {

View File

@@ -6,6 +6,7 @@ import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
@@ -17,6 +18,7 @@ import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.util.SelectUtils;
import org.apache.commons.lang3.StringUtils;
@@ -172,6 +174,33 @@ public class SqlParserUpdateHelper {
return selectStatement.toString();
}
public static String addFunctionToSelect(String sql, Expression expression) {
PlainSelect plainSelect = SqlParserSelectHelper.getPlainSelect(sql);
if (Objects.isNull(plainSelect)) {
return sql;
}
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isEmpty(selectItems)) {
return sql;
}
boolean existFunction = false;
for (SelectItem selectItem : selectItems) {
SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem;
if (expressionItem.getExpression() instanceof Function) {
Function expressionFunction = (Function) expressionItem.getExpression();
if (expression.toString().equalsIgnoreCase(expressionFunction.toString())) {
existFunction = true;
break;
}
}
}
if (!existFunction) {
SelectExpressionItem sumExpressionItem = new SelectExpressionItem(expression);
selectItems.add(sumExpressionItem);
}
return plainSelect.toString();
}
public static String replaceTable(String sql, String tableName) {
if (StringUtils.isEmpty(tableName)) {
return sql;

View File

@@ -250,4 +250,15 @@ class SqlParserSelectHelperTest {
}
@Test
void getHavingExpression() {
String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND "
+ "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
Expression leftExpression = SqlParserSelectHelper.getHavingExpression(sql);
Assert.assertEquals(leftExpression.toString(), "sum(pv)");
}
}

View File

@@ -242,6 +242,30 @@ class SqlParserUpdateHelperTest {
}
@Test
void addFunctionToSelect() {
String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND "
+ "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
System.out.println(replaceSql);
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' "
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
replaceSql);
sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND "
+ "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
System.out.println(replaceSql);
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' "
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
replaceSql);
}
private Map<String, String> initParams() {
Map<String, String> fieldToBizName = new HashMap<>();
fieldToBizName.put("部门", "department");

View File

@@ -64,7 +64,7 @@
<xk.time.version>3.2.4</xk.time.version>
<mockito-inline.version>4.5.1</mockito-inline.version>
<jsqlparser.version>4.5</jsqlparser.version>
<revision>0.7.4</revision>
<revision>0.7.5-SNAPSHOT</revision>
</properties>
<dependencyManagement>