diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java index 2ee2bd317..870cce668 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java @@ -10,13 +10,15 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); private Map fieldToBizName; + private boolean exactReplace; - public FieldReplaceVisitor(Map fieldToBizName) { + public FieldReplaceVisitor(Map fieldToBizName, boolean exactReplace) { this.fieldToBizName = fieldToBizName; + this.exactReplace = exactReplace; } @Override public void visit(Column column) { - parseVisitorHelper.replaceColumn(column, fieldToBizName); + parseVisitorHelper.replaceColumn(column, fieldToBizName, exactReplace); } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionAliasReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionAliasReplaceVisitor.java new file mode 100644 index 000000000..ccc8a6821 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionAliasReplaceVisitor.java @@ -0,0 +1,28 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; +import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter; + +public class FunctionAliasReplaceVisitor extends SelectItemVisitorAdapter { + + private Map aliasToActualExpression = new HashMap<>(); + + @Override + public void visit(SelectExpressionItem selectExpressionItem) { + if (selectExpressionItem.getExpression() instanceof Function) { + Function function = (Function) selectExpressionItem.getExpression(); + if (Objects.nonNull(selectExpressionItem.getAlias())) { + aliasToActualExpression.put(selectExpressionItem.getAlias().getName(), function.toString()); + selectExpressionItem.setAlias(null); + } + } + } + + public Map getAliasToActualExpression() { + return aliasToActualExpression; + } +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java index 2971a64a6..e60c465d9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java @@ -18,10 +18,11 @@ public class GroupByReplaceVisitor implements GroupByVisitor { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); private Map fieldToBizName; + private boolean exactReplace; - - public GroupByReplaceVisitor(Map fieldToBizName) { + public GroupByReplaceVisitor(Map fieldToBizName, boolean exactReplace) { this.fieldToBizName = fieldToBizName; + this.exactReplace = exactReplace; } public void visit(GroupByElement groupByElement) { @@ -32,7 +33,8 @@ public class GroupByReplaceVisitor implements GroupByVisitor { for (int i = 0; i < groupByExpressions.size(); i++) { Expression expression = groupByExpressions.get(i); - String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName); + String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName, + exactReplace); if (StringUtils.isNotEmpty(replaceColumn)) { if (expression instanceof Column) { groupByExpressions.set(i, new Column(replaceColumn)); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java index a3a4ccda6..858400f86 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java @@ -12,23 +12,25 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); private Map fieldToBizName; + private boolean exactReplace; - public OrderByReplaceVisitor(Map fieldToBizName) { + public OrderByReplaceVisitor(Map fieldToBizName, boolean exactReplace) { this.fieldToBizName = fieldToBizName; + this.exactReplace = exactReplace; } @Override public void visit(OrderByElement orderBy) { Expression expression = orderBy.getExpression(); if (expression instanceof Column) { - parseVisitorHelper.replaceColumn((Column) expression, fieldToBizName); + parseVisitorHelper.replaceColumn((Column) expression, fieldToBizName, exactReplace); } if (expression instanceof Function) { Function function = (Function) expression; List expressions = function.getParameters().getExpressions(); for (Expression column : expressions) { if (column instanceof Column) { - parseVisitorHelper.replaceColumn((Column) column, fieldToBizName); + parseVisitorHelper.replaceColumn((Column) column, fieldToBizName, exactReplace); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java index fde04277d..b526d6d24 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java @@ -11,27 +11,32 @@ import org.apache.commons.lang3.StringUtils; @Slf4j public class ParseVisitorHelper { - public void replaceColumn(Column column, Map fieldToBizName) { + public void replaceColumn(Column column, Map fieldToBizName, boolean exactReplace) { String columnName = column.getColumnName(); - column.setColumnName(getReplaceColumn(columnName, fieldToBizName)); + String replaceColumn = getReplaceColumn(columnName, fieldToBizName, exactReplace); + if (StringUtils.isNotBlank(replaceColumn)) { + column.setColumnName(replaceColumn); + } } - public String getReplaceColumn(String columnName, Map fieldToBizName) { + public String getReplaceColumn(String columnName, Map fieldToBizName, boolean exactReplace) { String fieldBizName = fieldToBizName.get(columnName); - if (StringUtils.isNotEmpty(fieldBizName)) { + if (StringUtils.isNotBlank(fieldBizName)) { return fieldBizName; - } else { - Optional> first = fieldToBizName.entrySet().stream().sorted((k1, k2) -> { - String k1FieldNameDb = k1.getKey(); - String k2FieldNameDb = k2.getKey(); - Double k1Similarity = getSimilarity(columnName, k1FieldNameDb); - Double k2Similarity = getSimilarity(columnName, k2FieldNameDb); - return k2Similarity.compareTo(k1Similarity); - }).collect(Collectors.toList()).stream().findFirst(); + } + if (exactReplace) { + return null; + } + Optional> first = fieldToBizName.entrySet().stream().sorted((k1, k2) -> { + String k1FieldNameDb = k1.getKey(); + String k2FieldNameDb = k2.getKey(); + Double k1Similarity = getSimilarity(columnName, k1FieldNameDb); + Double k2Similarity = getSimilarity(columnName, k2FieldNameDb); + return k2Similarity.compareTo(k1Similarity); + }).collect(Collectors.toList()).stream().findFirst(); - if (first.isPresent()) { - return first.get().getValue(); - } + if (first.isPresent()) { + return first.get().getValue(); } return columnName; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index e853864cf..5011182e2 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -58,8 +58,11 @@ public class SqlParserUpdateHelper { return selectStatement.toString(); } - public static String replaceFields(String sql, Map fieldToBizName) { + return replaceFields(sql, fieldToBizName, false); + } + + public static String replaceFields(String sql, Map fieldToBizName, boolean exactReplace) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); if (!(selectBody instanceof PlainSelect)) { @@ -68,7 +71,7 @@ public class SqlParserUpdateHelper { PlainSelect plainSelect = (PlainSelect) selectBody; //1. replace where fields Expression where = plainSelect.getWhere(); - FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldToBizName); + FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldToBizName, exactReplace); if (Objects.nonNull(where)) { where.accept(visitor); } @@ -82,14 +85,14 @@ public class SqlParserUpdateHelper { List orderByElements = plainSelect.getOrderByElements(); if (!CollectionUtils.isEmpty(orderByElements)) { for (OrderByElement orderByElement : orderByElements) { - orderByElement.accept(new OrderByReplaceVisitor(fieldToBizName)); + orderByElement.accept(new OrderByReplaceVisitor(fieldToBizName, exactReplace)); } } //4. replace group by fields GroupByElement groupByElement = plainSelect.getGroupBy(); if (Objects.nonNull(groupByElement)) { - groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName)); + groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName, exactReplace)); } return selectStatement.toString(); } @@ -178,6 +181,24 @@ public class SqlParserUpdateHelper { } + public static String replaceAlias(String sql) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + FunctionAliasReplaceVisitor visitor = new FunctionAliasReplaceVisitor(); + for (SelectItem selectItem : plainSelect.getSelectItems()) { + selectItem.accept(visitor); + } + Map aliasToActualExpression = visitor.getAliasToActualExpression(); + if (Objects.nonNull(aliasToActualExpression) && !aliasToActualExpression.isEmpty()) { + return replaceFields(selectStatement.toString(), aliasToActualExpression, true); + } + return selectStatement.toString(); + } + public static String addWhere(String sql, String column, Object value) { if (StringUtils.isEmpty(column) || Objects.isNull(value)) { return sql; diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index b9d46a879..c49e8ea3e 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -217,6 +217,19 @@ class SqlParserUpdateHelperTest { replaceSql); } + @Test + void replaceAlias() { + String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; + String replaceSql = SqlParserUpdateHelper.replaceAlias(sql); + System.out.println(replaceSql); + Assert.assertEquals( + "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", + replaceSql); + + } + private Map initParams() { Map fieldToBizName = new HashMap<>(); fieldToBizName.put("部门", "department"); diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index 1c1b866e2..0e599ccd4 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -32,6 +32,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.corrector.DateFieldCorrector, \ + com.tencent.supersonic.chat.corrector.FunctionAliasReplaceVisitor, \ com.tencent.supersonic.chat.corrector.FieldNameCorrector, \ com.tencent.supersonic.chat.corrector.FieldCorrector, \ com.tencent.supersonic.chat.corrector.FunctionCorrector, \ diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 14609d5fd..9efea619b 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -32,6 +32,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.corrector.DateFieldCorrector, \ + com.tencent.supersonic.chat.corrector.FunctionAliasReplaceVisitor, \ com.tencent.supersonic.chat.corrector.FieldNameCorrector, \ com.tencent.supersonic.chat.corrector.FieldCorrector, \ com.tencent.supersonic.chat.corrector.FunctionCorrector, \ diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java index 44717076d..a83f7d70e 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java @@ -1,18 +1,18 @@ package com.tencent.supersonic.semantic.query.parser.calcite.sql.render; import com.tencent.supersonic.semantic.api.query.request.MetricReq; -import com.tencent.supersonic.semantic.query.parser.calcite.sql.Renderer; -import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.AggFunctionNode; -import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.DataSourceNode; -import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.FilterNode; -import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.MetricNode; -import com.tencent.supersonic.semantic.query.parser.calcite.sql.TableView; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Constants; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.DataSource; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Dimension; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Identify; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Metric; import com.tencent.supersonic.semantic.query.parser.calcite.schema.SemanticSchema; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.Renderer; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.TableView; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.AggFunctionNode; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.DataSourceNode; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.FilterNode; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.MetricNode; import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.SemanticNode; import java.util.ArrayList; import java.util.Arrays; @@ -118,7 +118,7 @@ public class JoinRender extends Renderer { innerView.setTable(left); filterView.setTable(SemanticNode.buildAs(Constants.JOIN_TABLE_OUT_PREFIX, innerView.build())); if (!filterDimension.isEmpty()) { - for (String d : filterDimension) { + for (String d : getQueryDimension(filterDimension, queryAllDimension, whereFields)) { if (nonAgg) { filterView.getMeasure().add(SemanticNode.parse(d, scope)); } else { @@ -183,6 +183,12 @@ public class JoinRender extends Renderer { } } + private Set getQueryDimension(Set filterDimension, Set queryAllDimension, + Set whereFields) { + return filterDimension.stream().filter(d -> queryAllDimension.contains(d) || whereFields.contains(d)).collect( + Collectors.toSet()); + } + private boolean getMatchMetric(SemanticSchema schema, Set sourceMeasure, String m, List queryMetrics) {