mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(common) support addAggregateToField and addGroupBy and convert metricFilter to Having (#140)
This commit is contained in:
@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -32,6 +33,11 @@ public class GlobalCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class FiledExpression {
|
||||
|
||||
private String operator;
|
||||
|
||||
private String fieldName;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||
private Set<String> fieldNames;
|
||||
|
||||
public FiledFilterReplaceVisitor(Set<String> fieldNames) {
|
||||
this.fieldNames = fieldNames;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThan expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(EqualsTo expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThanEquals expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThan expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThanEquals expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
public List<Expression> getWaitingForAdds() {
|
||||
return waitingForAdds;
|
||||
}
|
||||
|
||||
|
||||
public List<Expression> parserFilter(ComparisonOperator comparisonOperator) {
|
||||
List<Expression> result = new ArrayList<>();
|
||||
String toString = comparisonOperator.toString();
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
return result;
|
||||
}
|
||||
Function leftExpressionFunction = (Function) leftExpression;
|
||||
if (leftExpressionFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) {
|
||||
return result;
|
||||
}
|
||||
|
||||
List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
|
||||
if (CollectionUtils.isEmpty(leftExpressions)) {
|
||||
return result;
|
||||
}
|
||||
Column field = (Column) leftExpressions.get(0);
|
||||
String columnName = field.getColumnName();
|
||||
if (!fieldNames.contains(columnName)) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 ");
|
||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
||||
comparisonOperator.setASTNode(expression.getASTNode());
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(toString));
|
||||
return result;
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -205,6 +205,30 @@ public class SqlParserSelectHelper {
|
||||
|
||||
|
||||
public static boolean hasAggregateFunction(String sql) {
|
||||
if (hasFunction(sql)) {
|
||||
return true;
|
||||
}
|
||||
return hasGroupBy(sql);
|
||||
}
|
||||
|
||||
public static boolean hasGroupBy(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return false;
|
||||
}
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
GroupByElement groupBy = plainSelect.getGroupBy();
|
||||
if (Objects.nonNull(groupBy)) {
|
||||
GroupByVisitor replaceVisitor = new GroupByVisitor();
|
||||
groupBy.accept(replaceVisitor);
|
||||
return replaceVisitor.isHasAggregateFunction();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean hasFunction(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
@@ -221,12 +245,6 @@ public class SqlParserSelectHelper {
|
||||
if (selectFunction) {
|
||||
return true;
|
||||
}
|
||||
GroupByElement groupBy = plainSelect.getGroupBy();
|
||||
if (Objects.nonNull(groupBy)) {
|
||||
GroupByVisitor replaceVisitor = new GroupByVisitor();
|
||||
groupBy.accept(replaceVisitor);
|
||||
return replaceVisitor.isHasAggregateFunction();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||
@@ -20,6 +21,7 @@ import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.select.SelectBody;
|
||||
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||
import net.sf.jsqlparser.util.SelectUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -278,5 +280,114 @@ public class SqlParserUpdateHelper {
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String addAggregateToField(String sql, Map<String, String> fieldNameToAggregate) {
|
||||
if (SqlParserSelectHelper.hasGroupBy(sql)) {
|
||||
return sql;
|
||||
}
|
||||
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
selectBody.accept(new SelectVisitorAdapter() {
|
||||
@Override
|
||||
public void visit(PlainSelect plainSelect) {
|
||||
addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate);
|
||||
addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate);
|
||||
}
|
||||
});
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String addGroupBy(String sql, List<String> groupByFields) {
|
||||
if (SqlParserSelectHelper.hasGroupBy(sql)) {
|
||||
return sql;
|
||||
}
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
GroupByElement groupByElement = new GroupByElement();
|
||||
for (String groupByField : groupByFields) {
|
||||
groupByElement.addGroupByExpression(new Column(groupByField));
|
||||
}
|
||||
plainSelect.setGroupByElement(groupByElement);
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
private static void addAggregateToSelectItems(List<SelectItem> selectItems,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
for (SelectItem selectItem : selectItems) {
|
||||
if (selectItem instanceof SelectExpressionItem) {
|
||||
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
|
||||
Expression expression = selectExpressionItem.getExpression();
|
||||
String columnName = ((Column) expression).getColumnName();
|
||||
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
|
||||
if (Objects.isNull(function)) {
|
||||
continue;
|
||||
}
|
||||
selectExpressionItem.setExpression(function);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
if (orderByElements == null) {
|
||||
return;
|
||||
}
|
||||
for (OrderByElement orderByElement : orderByElements) {
|
||||
Expression expression = orderByElement.getExpression();
|
||||
String columnName = ((Column) expression).getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
continue;
|
||||
}
|
||||
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
|
||||
if (Objects.isNull(function)) {
|
||||
continue;
|
||||
}
|
||||
orderByElement.setExpression(function);
|
||||
}
|
||||
}
|
||||
|
||||
private static Function getFunction(Expression expression, String aggregateName) {
|
||||
if (StringUtils.isEmpty(aggregateName)) {
|
||||
return null;
|
||||
}
|
||||
Function sumFunction = new Function();
|
||||
sumFunction.setName(aggregateName);
|
||||
sumFunction.setParameters(new ExpressionList(expression));
|
||||
return sumFunction;
|
||||
}
|
||||
|
||||
public static String addHaving(String sql, Set<String> fieldNames) {
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
//replace metric to 1 and 1 and add having metric
|
||||
Expression where = plainSelect.getWhere();
|
||||
FiledFilterReplaceVisitor visitor = new FiledFilterReplaceVisitor(fieldNames);
|
||||
if (Objects.nonNull(where)) {
|
||||
where.accept(visitor);
|
||||
}
|
||||
List<Expression> waitingForAdds = visitor.getWaitingForAdds();
|
||||
if (!CollectionUtils.isEmpty(waitingForAdds)) {
|
||||
for (Expression waitingForAdd : waitingForAdds) {
|
||||
plainSelect.setHaving(waitingForAdd);
|
||||
}
|
||||
}
|
||||
return selectStatement.toString();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
@@ -266,6 +269,85 @@ class SqlParserUpdateHelperTest {
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void addAggregateToField() {
|
||||
String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND "
|
||||
+ "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
|
||||
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
System.out.println(replaceSql);
|
||||
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' "
|
||||
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
|
||||
replaceSql);
|
||||
|
||||
sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND "
|
||||
+ "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
|
||||
havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
System.out.println(replaceSql);
|
||||
Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' "
|
||||
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000",
|
||||
replaceSql);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void addAggregateToMetricField() {
|
||||
String sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' order by pv desc limit 10";
|
||||
|
||||
Map<String, String> filedNameToAggregate = new HashMap<>();
|
||||
filedNameToAggregate.put("pv", "sum");
|
||||
|
||||
List<String> groupByFields = new ArrayList<>();
|
||||
groupByFields.add("department");
|
||||
|
||||
String replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void addGroupBy() {
|
||||
String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' "
|
||||
+ "order by sum(pv) desc limit 10";
|
||||
|
||||
List<String> groupByFields = new ArrayList<>();
|
||||
groupByFields.add("department");
|
||||
|
||||
String replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void addHaving() {
|
||||
String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and "
|
||||
+ "sum(pv) > 2000 group by department order by sum(pv) desc limit 10";
|
||||
List<String> groupByFields = new ArrayList<>();
|
||||
groupByFields.add("department");
|
||||
|
||||
Set<String> fieldNames = new HashSet<>();
|
||||
fieldNames.add("pv");
|
||||
|
||||
String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "AND 1 > 1 GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
|
||||
private Map<String, String> initParams() {
|
||||
Map<String, String> fieldToBizName = new HashMap<>();
|
||||
fieldToBizName.put("部门", "department");
|
||||
|
||||
Reference in New Issue
Block a user