From be158a1776b1809be54bd8e9b872e11987066027 Mon Sep 17 00:00:00 2001 From: jipeli <54889677+jipeli@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:13:55 +0800 Subject: [PATCH] [improvement][headless] fix derived metric aggOption error (#679) --- .../QueryExpressionReplaceVisitor.java | 9 ++++ .../parser/calcite/schema/SchemaBuilder.java | 3 +- .../converter/CalculateAggConverter.java | 1 + .../headless/core/utils/SqlGenerateUtils.java | 50 ++++++++++--------- .../server/utils/QueryReqConverter.java | 14 ++++-- 5 files changed, 47 insertions(+), 30 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/QueryExpressionReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/QueryExpressionReplaceVisitor.java index a6afc3dbc..b9547fd25 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/QueryExpressionReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/QueryExpressionReplaceVisitor.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.common.util.jsqlparser; import java.util.Map; import java.util.Objects; +import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; @@ -46,20 +47,28 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter { Expression expression = selectExpressionItem.getExpression(); String toReplace = ""; + String columnName = ""; if (expression instanceof Function) { Function leftFunc = (Function) expression; if (leftFunc.getParameters().getExpressions().get(0) instanceof Column) { + Column column = (Column) leftFunc.getParameters().getExpressions().get(0); + columnName = column.getColumnName(); toReplace = getReplaceExpr(leftFunc, fieldExprMap); } } if (expression instanceof Column) { + Column column = (Column) expression; + columnName = column.getColumnName(); toReplace = getReplaceExpr((Column) expression, fieldExprMap); } if (!toReplace.isEmpty()) { Expression toReplaceExpr = getExpression(toReplace); if (Objects.nonNull(toReplaceExpr)) { selectExpressionItem.setExpression(toReplaceExpr); + if (Objects.isNull(selectExpressionItem.getAlias())) { + selectExpressionItem.setAlias(new Alias(columnName, true)); + } } } //selectExpressionItem.getExpression().accept(this); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/calcite/schema/SchemaBuilder.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/calcite/schema/SchemaBuilder.java index b7146515c..01104a91a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/calcite/schema/SchemaBuilder.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/calcite/schema/SchemaBuilder.java @@ -4,7 +4,6 @@ package com.tencent.supersonic.headless.core.parser.calcite.schema; import com.tencent.supersonic.headless.api.enums.EngineType; import com.tencent.supersonic.headless.core.parser.calcite.Configuration; import com.tencent.supersonic.headless.core.parser.calcite.sql.S2SQLSqlValidatorImpl; - import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -74,7 +73,7 @@ public class SchemaBuilder { builder.addField(dim, SqlTypeName.VARCHAR); } for (String metric : metrics) { - builder.addField(metric, SqlTypeName.BIGINT); + builder.addField(metric, SqlTypeName.ANY); } DataSourceTable srcTable = builder .withRowCount(1) diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/converter/CalculateAggConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/converter/CalculateAggConverter.java index edcb95084..78f33c4b1 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/converter/CalculateAggConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/converter/CalculateAggConverter.java @@ -135,6 +135,7 @@ public class CalculateAggConverter implements HeadlessConverter { throws Exception { QueryStructReq queryStructReq = queryStatement.getQueryStructReq(); check(queryStructReq); + queryStatement.setEnableOptimize(false); ParseSqlReq sqlCommand = new ParseSqlReq(); sqlCommand.setRootPath(queryStructReq.getModelIdStr()); String metricTableName = "v_metric_tb_tmp"; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java index a0153cb3d..f67c91e4a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java @@ -1,5 +1,12 @@ package com.tencent.supersonic.headless.core.utils; +import static com.tencent.supersonic.common.pojo.Constants.DAY; +import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT; +import static com.tencent.supersonic.common.pojo.Constants.JOIN_UNDERLINE; +import static com.tencent.supersonic.common.pojo.Constants.MONTH; +import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE; +import static com.tencent.supersonic.common.pojo.Constants.WEEK; + import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.ItemDateResp; @@ -10,21 +17,13 @@ import com.tencent.supersonic.common.util.SqlFilterUtils; import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.headless.api.enums.AggOption; import com.tencent.supersonic.headless.api.enums.EngineType; import com.tencent.supersonic.headless.api.enums.MetricDefineType; import com.tencent.supersonic.headless.api.pojo.Measure; import com.tencent.supersonic.headless.api.request.QueryStructReq; import com.tencent.supersonic.headless.api.response.DimensionResp; import com.tencent.supersonic.headless.api.response.MetricResp; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Triple; -import org.apache.logging.log4j.util.Strings; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; - import java.time.LocalDate; import java.time.format.DateTimeFormatter; import java.util.Collections; @@ -36,13 +35,14 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; - -import static com.tencent.supersonic.common.pojo.Constants.DAY; -import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT; -import static com.tencent.supersonic.common.pojo.Constants.JOIN_UNDERLINE; -import static com.tencent.supersonic.common.pojo.Constants.MONTH; -import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE; -import static com.tencent.supersonic.common.pojo.Constants.WEEK; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Triple; +import org.apache.logging.log4j.util.Strings; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; /** * tools functions to analyze queryStructReq @@ -271,7 +271,8 @@ public class SqlGenerateUtils { public String generateDerivedMetric(final List metricResps, final Set allFields, final Map allMeasures, final List dimensionResps, - final String expression, final MetricDefineType metricDefineType, Set visitedMetric, + final String expression, final MetricDefineType metricDefineType, AggOption aggOption, + Set visitedMetric, Set measures, Set dimensions) { Set fields = SqlParserSelectHelper.getColumnFromExpr(expression); @@ -289,14 +290,14 @@ public class SqlGenerateUtils { replace.put(field, generateDerivedMetric(metricResps, allFields, allMeasures, dimensionResps, getExpr(metricItem.get()), metricItem.get().getMetricDefineType(), - visitedMetric, measures, dimensions)); + aggOption, visitedMetric, measures, dimensions)); visitedMetric.add(field); } break; case MEASURE: if (allMeasures.containsKey(field)) { measures.add(field); - replace.put(field, getExpr(allMeasures.get(field))); + replace.put(field, getExpr(allMeasures.get(field), aggOption)); } break; case FIELD: @@ -324,12 +325,15 @@ public class SqlGenerateUtils { return expression; } - public String getExpr(Measure measure) { + public String getExpr(Measure measure, AggOption aggOption) { if (AggOperatorEnum.COUNT_DISTINCT.getOperator().equalsIgnoreCase(measure.getAgg())) { - return AggOperatorEnum.COUNT.getOperator() + " ( " + AggOperatorEnum.DISTINCT + " " + measure.getBizName() - + " ) "; + return aggOption.equals(AggOption.NATIVE) ? measure.getBizName() + : AggOperatorEnum.COUNT.getOperator() + " ( " + AggOperatorEnum.DISTINCT + " " + + measure.getBizName() + + " ) "; } - return measure.getAgg() + " ( " + measure.getBizName() + " ) "; + return aggOption.equals(AggOption.NATIVE) ? measure.getBizName() + : measure.getAgg() + " ( " + measure.getBizName() + " ) "; } public String getExpr(MetricResp metricResp) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java index 9c0bf8823..72867f716 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.server.utils; + import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; @@ -121,7 +122,7 @@ public class QueryReqConverter { result.setWithAlias(false); } //5. do deriveMetric - generateDerivedMetric(querySQLReq.getModelIds(), modelSchemaResps, result); + generateDerivedMetric(querySQLReq.getModelIds(), modelSchemaResps, aggOption, result); //6.physicalSql by ParseSqlReq queryStructReq.setDateInfo(queryStructUtils.getDateConfBySql(querySQLReq.getSql())); queryStructReq.setModelIds(new HashSet<>(querySQLReq.getModelIds())); @@ -242,13 +243,14 @@ public class QueryReqConverter { return queryType; } - private void generateDerivedMetric(List modelIds, List modelSchemaResps, + private void generateDerivedMetric(List modelIds, List modelSchemaResps, AggOption aggOption, ParseSqlReq parseSqlReq) { String sql = parseSqlReq.getSql(); for (MetricTable metricTable : parseSqlReq.getTables()) { List measures = new ArrayList<>(); Map replaces = new HashMap<>(); - generateDerivedMetric(modelIds, modelSchemaResps, metricTable.getMetrics(), metricTable.getDimensions(), + generateDerivedMetric(modelIds, modelSchemaResps, aggOption, metricTable.getMetrics(), + metricTable.getDimensions(), measures, replaces); if (!CollectionUtils.isEmpty(replaces)) { // metricTable sql use measures replace metric @@ -263,7 +265,7 @@ public class QueryReqConverter { parseSqlReq.setSql(sql); } - private void generateDerivedMetric(List modelIds, List modelSchemaResps, + private void generateDerivedMetric(List modelIds, List modelSchemaResps, AggOption aggOption, List metrics, List dimensions, List measures, Map replaces) { MetaFilter metaFilter = new MetaFilter(); @@ -276,6 +278,7 @@ public class QueryReqConverter { m.getMetricDefineByMeasureParams()))) { return; } + log.info("begin to generateDerivedMetric {} [{}]", aggOption, metrics); Set allFields = new HashSet<>(); Map allMeasures = new HashMap<>(); modelSchemaResps.stream().forEach(modelSchemaResp -> { @@ -296,7 +299,8 @@ public class QueryReqConverter { metricResp.getMetricDefineByMeasureParams())) { String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields, allMeasures, dimensionResps, - sqlGenerateUtils.getExpr(metricResp), metricResp.getMetricDefineType(), visitedMetric, + sqlGenerateUtils.getExpr(metricResp), metricResp.getMetricDefineType(), aggOption, + visitedMetric, deriveMetric, deriveDimension); replaces.put(metricResp.getBizName(), expr); log.info("derived metric {}->{}", metricResp.getBizName(), expr);