(improvement)(chat) remove group by dimension and add FunctionAliasReplaceVisitor in dsl (#77)

* (improvement)(chat) remove group by dimension in join case

* (improvement)(chat) add FunctionAliasReplaceVisitor in dsl

---------
This commit is contained in:
lexluo09
2023-09-12 17:04:21 +08:00
committed by jerryjzhang
parent c6b87d30a5
commit 2c621a1338
10 changed files with 115 additions and 34 deletions

View File

@@ -10,13 +10,15 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName;
private boolean exactReplace;
public FieldReplaceVisitor(Map<String, String> fieldToBizName) {
public FieldReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) {
this.fieldToBizName = fieldToBizName;
this.exactReplace = exactReplace;
}
@Override
public void visit(Column column) {
parseVisitorHelper.replaceColumn(column, fieldToBizName);
parseVisitorHelper.replaceColumn(column, fieldToBizName, exactReplace);
}
}

View File

@@ -0,0 +1,28 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter;
public class FunctionAliasReplaceVisitor extends SelectItemVisitorAdapter {
private Map<String, String> aliasToActualExpression = new HashMap<>();
@Override
public void visit(SelectExpressionItem selectExpressionItem) {
if (selectExpressionItem.getExpression() instanceof Function) {
Function function = (Function) selectExpressionItem.getExpression();
if (Objects.nonNull(selectExpressionItem.getAlias())) {
aliasToActualExpression.put(selectExpressionItem.getAlias().getName(), function.toString());
selectExpressionItem.setAlias(null);
}
}
}
public Map<String, String> getAliasToActualExpression() {
return aliasToActualExpression;
}
}

View File

@@ -18,10 +18,11 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName;
private boolean exactReplace;
public GroupByReplaceVisitor(Map<String, String> fieldToBizName) {
public GroupByReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) {
this.fieldToBizName = fieldToBizName;
this.exactReplace = exactReplace;
}
public void visit(GroupByElement groupByElement) {
@@ -32,7 +33,8 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
for (int i = 0; i < groupByExpressions.size(); i++) {
Expression expression = groupByExpressions.get(i);
String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName);
String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName,
exactReplace);
if (StringUtils.isNotEmpty(replaceColumn)) {
if (expression instanceof Column) {
groupByExpressions.set(i, new Column(replaceColumn));

View File

@@ -12,23 +12,25 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName;
private boolean exactReplace;
public OrderByReplaceVisitor(Map<String, String> fieldToBizName) {
public OrderByReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) {
this.fieldToBizName = fieldToBizName;
this.exactReplace = exactReplace;
}
@Override
public void visit(OrderByElement orderBy) {
Expression expression = orderBy.getExpression();
if (expression instanceof Column) {
parseVisitorHelper.replaceColumn((Column) expression, fieldToBizName);
parseVisitorHelper.replaceColumn((Column) expression, fieldToBizName, exactReplace);
}
if (expression instanceof Function) {
Function function = (Function) expression;
List<Expression> expressions = function.getParameters().getExpressions();
for (Expression column : expressions) {
if (column instanceof Column) {
parseVisitorHelper.replaceColumn((Column) column, fieldToBizName);
parseVisitorHelper.replaceColumn((Column) column, fieldToBizName, exactReplace);
}
}
}

View File

@@ -11,27 +11,32 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j
public class ParseVisitorHelper {
public void replaceColumn(Column column, Map<String, String> fieldToBizName) {
public void replaceColumn(Column column, Map<String, String> fieldToBizName, boolean exactReplace) {
String columnName = column.getColumnName();
column.setColumnName(getReplaceColumn(columnName, fieldToBizName));
String replaceColumn = getReplaceColumn(columnName, fieldToBizName, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) {
column.setColumnName(replaceColumn);
}
}
public String getReplaceColumn(String columnName, Map<String, String> fieldToBizName) {
public String getReplaceColumn(String columnName, Map<String, String> fieldToBizName, boolean exactReplace) {
String fieldBizName = fieldToBizName.get(columnName);
if (StringUtils.isNotEmpty(fieldBizName)) {
if (StringUtils.isNotBlank(fieldBizName)) {
return fieldBizName;
} else {
Optional<Entry<String, String>> first = fieldToBizName.entrySet().stream().sorted((k1, k2) -> {
String k1FieldNameDb = k1.getKey();
String k2FieldNameDb = k2.getKey();
Double k1Similarity = getSimilarity(columnName, k1FieldNameDb);
Double k2Similarity = getSimilarity(columnName, k2FieldNameDb);
return k2Similarity.compareTo(k1Similarity);
}).collect(Collectors.toList()).stream().findFirst();
}
if (exactReplace) {
return null;
}
Optional<Entry<String, String>> first = fieldToBizName.entrySet().stream().sorted((k1, k2) -> {
String k1FieldNameDb = k1.getKey();
String k2FieldNameDb = k2.getKey();
Double k1Similarity = getSimilarity(columnName, k1FieldNameDb);
Double k2Similarity = getSimilarity(columnName, k2FieldNameDb);
return k2Similarity.compareTo(k1Similarity);
}).collect(Collectors.toList()).stream().findFirst();
if (first.isPresent()) {
return first.get().getValue();
}
if (first.isPresent()) {
return first.get().getValue();
}
return columnName;
}

View File

@@ -58,8 +58,11 @@ public class SqlParserUpdateHelper {
return selectStatement.toString();
}
public static String replaceFields(String sql, Map<String, String> fieldToBizName) {
return replaceFields(sql, fieldToBizName, false);
}
public static String replaceFields(String sql, Map<String, String> fieldToBizName, boolean exactReplace) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
@@ -68,7 +71,7 @@ public class SqlParserUpdateHelper {
PlainSelect plainSelect = (PlainSelect) selectBody;
//1. replace where fields
Expression where = plainSelect.getWhere();
FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldToBizName);
FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldToBizName, exactReplace);
if (Objects.nonNull(where)) {
where.accept(visitor);
}
@@ -82,14 +85,14 @@ public class SqlParserUpdateHelper {
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
if (!CollectionUtils.isEmpty(orderByElements)) {
for (OrderByElement orderByElement : orderByElements) {
orderByElement.accept(new OrderByReplaceVisitor(fieldToBizName));
orderByElement.accept(new OrderByReplaceVisitor(fieldToBizName, exactReplace));
}
}
//4. replace group by fields
GroupByElement groupByElement = plainSelect.getGroupBy();
if (Objects.nonNull(groupByElement)) {
groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName));
groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName, exactReplace));
}
return selectStatement.toString();
}
@@ -178,6 +181,24 @@ public class SqlParserUpdateHelper {
}
public static String replaceAlias(String sql) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
PlainSelect plainSelect = (PlainSelect) selectBody;
FunctionAliasReplaceVisitor visitor = new FunctionAliasReplaceVisitor();
for (SelectItem selectItem : plainSelect.getSelectItems()) {
selectItem.accept(visitor);
}
Map<String, String> aliasToActualExpression = visitor.getAliasToActualExpression();
if (Objects.nonNull(aliasToActualExpression) && !aliasToActualExpression.isEmpty()) {
return replaceFields(selectStatement.toString(), aliasToActualExpression, true);
}
return selectStatement.toString();
}
public static String addWhere(String sql, String column, Object value) {
if (StringUtils.isEmpty(column) || Objects.isNull(value)) {
return sql;

View File

@@ -217,6 +217,19 @@ class SqlParserUpdateHelperTest {
replaceSql);
}
@Test
void replaceAlias() {
String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10";
String replaceSql = SqlParserUpdateHelper.replaceAlias(sql);
System.out.println(replaceSql);
Assert.assertEquals(
"SELECT 部门, sum(访问次数) FROM 超音数 WHERE "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10",
replaceSql);
}
private Map<String, String> initParams() {
Map<String, String> fieldToBizName = new HashMap<>();
fieldToBizName.put("部门", "department");

View File

@@ -32,6 +32,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.DateFieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionAliasReplaceVisitor, \
com.tencent.supersonic.chat.corrector.FieldNameCorrector, \
com.tencent.supersonic.chat.corrector.FieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionCorrector, \

View File

@@ -32,6 +32,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.DateFieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionAliasReplaceVisitor, \
com.tencent.supersonic.chat.corrector.FieldNameCorrector, \
com.tencent.supersonic.chat.corrector.FieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionCorrector, \

View File

@@ -1,18 +1,18 @@
package com.tencent.supersonic.semantic.query.parser.calcite.sql.render;
import com.tencent.supersonic.semantic.api.query.request.MetricReq;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.Renderer;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.AggFunctionNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.DataSourceNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.FilterNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.MetricNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.TableView;
import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Constants;
import com.tencent.supersonic.semantic.query.parser.calcite.dsl.DataSource;
import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Dimension;
import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Identify;
import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Metric;
import com.tencent.supersonic.semantic.query.parser.calcite.schema.SemanticSchema;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.Renderer;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.TableView;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.AggFunctionNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.DataSourceNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.FilterNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.MetricNode;
import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.SemanticNode;
import java.util.ArrayList;
import java.util.Arrays;
@@ -118,7 +118,7 @@ public class JoinRender extends Renderer {
innerView.setTable(left);
filterView.setTable(SemanticNode.buildAs(Constants.JOIN_TABLE_OUT_PREFIX, innerView.build()));
if (!filterDimension.isEmpty()) {
for (String d : filterDimension) {
for (String d : getQueryDimension(filterDimension, queryAllDimension, whereFields)) {
if (nonAgg) {
filterView.getMeasure().add(SemanticNode.parse(d, scope));
} else {
@@ -183,6 +183,12 @@ public class JoinRender extends Renderer {
}
}
private Set<String> getQueryDimension(Set<String> filterDimension, Set<String> queryAllDimension,
Set<String> whereFields) {
return filterDimension.stream().filter(d -> queryAllDimension.contains(d) || whereFields.contains(d)).collect(
Collectors.toSet());
}
private boolean getMatchMetric(SemanticSchema schema, Set<String> sourceMeasure, String m,
List<String> queryMetrics) {