[improvement][headless] fix derived metric aggOption error (#679)

This commit is contained in:
jipeli
2024-01-22 19:13:55 +08:00
committed by GitHub
parent c12f5d23f0
commit be158a1776
5 changed files with 47 additions and 30 deletions

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.common.util.jsqlparser;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
@@ -46,20 +47,28 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter {
Expression expression = selectExpressionItem.getExpression(); Expression expression = selectExpressionItem.getExpression();
String toReplace = ""; String toReplace = "";
String columnName = "";
if (expression instanceof Function) { if (expression instanceof Function) {
Function leftFunc = (Function) expression; Function leftFunc = (Function) expression;
if (leftFunc.getParameters().getExpressions().get(0) instanceof Column) { if (leftFunc.getParameters().getExpressions().get(0) instanceof Column) {
Column column = (Column) leftFunc.getParameters().getExpressions().get(0);
columnName = column.getColumnName();
toReplace = getReplaceExpr(leftFunc, fieldExprMap); toReplace = getReplaceExpr(leftFunc, fieldExprMap);
} }
} }
if (expression instanceof Column) { if (expression instanceof Column) {
Column column = (Column) expression;
columnName = column.getColumnName();
toReplace = getReplaceExpr((Column) expression, fieldExprMap); toReplace = getReplaceExpr((Column) expression, fieldExprMap);
} }
if (!toReplace.isEmpty()) { if (!toReplace.isEmpty()) {
Expression toReplaceExpr = getExpression(toReplace); Expression toReplaceExpr = getExpression(toReplace);
if (Objects.nonNull(toReplaceExpr)) { if (Objects.nonNull(toReplaceExpr)) {
selectExpressionItem.setExpression(toReplaceExpr); selectExpressionItem.setExpression(toReplaceExpr);
if (Objects.isNull(selectExpressionItem.getAlias())) {
selectExpressionItem.setAlias(new Alias(columnName, true));
}
} }
} }
//selectExpressionItem.getExpression().accept(this); //selectExpressionItem.getExpression().accept(this);

View File

@@ -4,7 +4,6 @@ package com.tencent.supersonic.headless.core.parser.calcite.schema;
import com.tencent.supersonic.headless.api.enums.EngineType; import com.tencent.supersonic.headless.api.enums.EngineType;
import com.tencent.supersonic.headless.core.parser.calcite.Configuration; import com.tencent.supersonic.headless.core.parser.calcite.Configuration;
import com.tencent.supersonic.headless.core.parser.calcite.sql.S2SQLSqlValidatorImpl; import com.tencent.supersonic.headless.core.parser.calcite.sql.S2SQLSqlValidatorImpl;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@@ -74,7 +73,7 @@ public class SchemaBuilder {
builder.addField(dim, SqlTypeName.VARCHAR); builder.addField(dim, SqlTypeName.VARCHAR);
} }
for (String metric : metrics) { for (String metric : metrics) {
builder.addField(metric, SqlTypeName.BIGINT); builder.addField(metric, SqlTypeName.ANY);
} }
DataSourceTable srcTable = builder DataSourceTable srcTable = builder
.withRowCount(1) .withRowCount(1)

View File

@@ -135,6 +135,7 @@ public class CalculateAggConverter implements HeadlessConverter {
throws Exception { throws Exception {
QueryStructReq queryStructReq = queryStatement.getQueryStructReq(); QueryStructReq queryStructReq = queryStatement.getQueryStructReq();
check(queryStructReq); check(queryStructReq);
queryStatement.setEnableOptimize(false);
ParseSqlReq sqlCommand = new ParseSqlReq(); ParseSqlReq sqlCommand = new ParseSqlReq();
sqlCommand.setRootPath(queryStructReq.getModelIdStr()); sqlCommand.setRootPath(queryStructReq.getModelIdStr());
String metricTableName = "v_metric_tb_tmp"; String metricTableName = "v_metric_tb_tmp";

View File

@@ -1,5 +1,12 @@
package com.tencent.supersonic.headless.core.utils; package com.tencent.supersonic.headless.core.utils;
import static com.tencent.supersonic.common.pojo.Constants.DAY;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.JOIN_UNDERLINE;
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE;
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.ItemDateResp; import com.tencent.supersonic.common.pojo.ItemDateResp;
@@ -10,21 +17,13 @@ import com.tencent.supersonic.common.util.SqlFilterUtils;
import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.headless.api.enums.AggOption;
import com.tencent.supersonic.headless.api.enums.EngineType; import com.tencent.supersonic.headless.api.enums.EngineType;
import com.tencent.supersonic.headless.api.enums.MetricDefineType; import com.tencent.supersonic.headless.api.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.Measure; import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.request.QueryStructReq; import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.api.response.DimensionResp; import com.tencent.supersonic.headless.api.response.DimensionResp;
import com.tencent.supersonic.headless.api.response.MetricResp; import com.tencent.supersonic.headless.api.response.MetricResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.Collections; import java.util.Collections;
@@ -36,13 +35,14 @@ import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.common.pojo.Constants.DAY; import org.apache.commons.lang3.StringUtils;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT; import org.apache.commons.lang3.tuple.ImmutablePair;
import static com.tencent.supersonic.common.pojo.Constants.JOIN_UNDERLINE; import org.apache.commons.lang3.tuple.Triple;
import static com.tencent.supersonic.common.pojo.Constants.MONTH; import org.apache.logging.log4j.util.Strings;
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE; import org.springframework.beans.factory.annotation.Value;
import static com.tencent.supersonic.common.pojo.Constants.WEEK; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
/** /**
* tools functions to analyze queryStructReq * tools functions to analyze queryStructReq
@@ -271,7 +271,8 @@ public class SqlGenerateUtils {
public String generateDerivedMetric(final List<MetricResp> metricResps, final Set<String> allFields, public String generateDerivedMetric(final List<MetricResp> metricResps, final Set<String> allFields,
final Map<String, Measure> allMeasures, final List<DimensionResp> dimensionResps, final Map<String, Measure> allMeasures, final List<DimensionResp> dimensionResps,
final String expression, final MetricDefineType metricDefineType, Set<String> visitedMetric, final String expression, final MetricDefineType metricDefineType, AggOption aggOption,
Set<String> visitedMetric,
Set<String> measures, Set<String> measures,
Set<String> dimensions) { Set<String> dimensions) {
Set<String> fields = SqlParserSelectHelper.getColumnFromExpr(expression); Set<String> fields = SqlParserSelectHelper.getColumnFromExpr(expression);
@@ -289,14 +290,14 @@ public class SqlGenerateUtils {
replace.put(field, replace.put(field,
generateDerivedMetric(metricResps, allFields, allMeasures, dimensionResps, generateDerivedMetric(metricResps, allFields, allMeasures, dimensionResps,
getExpr(metricItem.get()), metricItem.get().getMetricDefineType(), getExpr(metricItem.get()), metricItem.get().getMetricDefineType(),
visitedMetric, measures, dimensions)); aggOption, visitedMetric, measures, dimensions));
visitedMetric.add(field); visitedMetric.add(field);
} }
break; break;
case MEASURE: case MEASURE:
if (allMeasures.containsKey(field)) { if (allMeasures.containsKey(field)) {
measures.add(field); measures.add(field);
replace.put(field, getExpr(allMeasures.get(field))); replace.put(field, getExpr(allMeasures.get(field), aggOption));
} }
break; break;
case FIELD: case FIELD:
@@ -324,12 +325,15 @@ public class SqlGenerateUtils {
return expression; return expression;
} }
public String getExpr(Measure measure) { public String getExpr(Measure measure, AggOption aggOption) {
if (AggOperatorEnum.COUNT_DISTINCT.getOperator().equalsIgnoreCase(measure.getAgg())) { if (AggOperatorEnum.COUNT_DISTINCT.getOperator().equalsIgnoreCase(measure.getAgg())) {
return AggOperatorEnum.COUNT.getOperator() + " ( " + AggOperatorEnum.DISTINCT + " " + measure.getBizName() return aggOption.equals(AggOption.NATIVE) ? measure.getBizName()
: AggOperatorEnum.COUNT.getOperator() + " ( " + AggOperatorEnum.DISTINCT + " "
+ measure.getBizName()
+ " ) "; + " ) ";
} }
return measure.getAgg() + " ( " + measure.getBizName() + " ) "; return aggOption.equals(AggOption.NATIVE) ? measure.getBizName()
: measure.getAgg() + " ( " + measure.getBizName() + " ) ";
} }
public String getExpr(MetricResp metricResp) { public String getExpr(MetricResp metricResp) {

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.server.utils; package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
@@ -121,7 +122,7 @@ public class QueryReqConverter {
result.setWithAlias(false); result.setWithAlias(false);
} }
//5. do deriveMetric //5. do deriveMetric
generateDerivedMetric(querySQLReq.getModelIds(), modelSchemaResps, result); generateDerivedMetric(querySQLReq.getModelIds(), modelSchemaResps, aggOption, result);
//6.physicalSql by ParseSqlReq //6.physicalSql by ParseSqlReq
queryStructReq.setDateInfo(queryStructUtils.getDateConfBySql(querySQLReq.getSql())); queryStructReq.setDateInfo(queryStructUtils.getDateConfBySql(querySQLReq.getSql()));
queryStructReq.setModelIds(new HashSet<>(querySQLReq.getModelIds())); queryStructReq.setModelIds(new HashSet<>(querySQLReq.getModelIds()));
@@ -242,13 +243,14 @@ public class QueryReqConverter {
return queryType; return queryType;
} }
private void generateDerivedMetric(List<Long> modelIds, List<ModelSchemaResp> modelSchemaResps, private void generateDerivedMetric(List<Long> modelIds, List<ModelSchemaResp> modelSchemaResps, AggOption aggOption,
ParseSqlReq parseSqlReq) { ParseSqlReq parseSqlReq) {
String sql = parseSqlReq.getSql(); String sql = parseSqlReq.getSql();
for (MetricTable metricTable : parseSqlReq.getTables()) { for (MetricTable metricTable : parseSqlReq.getTables()) {
List<String> measures = new ArrayList<>(); List<String> measures = new ArrayList<>();
Map<String, String> replaces = new HashMap<>(); Map<String, String> replaces = new HashMap<>();
generateDerivedMetric(modelIds, modelSchemaResps, metricTable.getMetrics(), metricTable.getDimensions(), generateDerivedMetric(modelIds, modelSchemaResps, aggOption, metricTable.getMetrics(),
metricTable.getDimensions(),
measures, replaces); measures, replaces);
if (!CollectionUtils.isEmpty(replaces)) { if (!CollectionUtils.isEmpty(replaces)) {
// metricTable sql use measures replace metric // metricTable sql use measures replace metric
@@ -263,7 +265,7 @@ public class QueryReqConverter {
parseSqlReq.setSql(sql); parseSqlReq.setSql(sql);
} }
private void generateDerivedMetric(List<Long> modelIds, List<ModelSchemaResp> modelSchemaResps, private void generateDerivedMetric(List<Long> modelIds, List<ModelSchemaResp> modelSchemaResps, AggOption aggOption,
List<String> metrics, List<String> dimensions, List<String> metrics, List<String> dimensions,
List<String> measures, Map<String, String> replaces) { List<String> measures, Map<String, String> replaces) {
MetaFilter metaFilter = new MetaFilter(); MetaFilter metaFilter = new MetaFilter();
@@ -276,6 +278,7 @@ public class QueryReqConverter {
m.getMetricDefineByMeasureParams()))) { m.getMetricDefineByMeasureParams()))) {
return; return;
} }
log.info("begin to generateDerivedMetric {} [{}]", aggOption, metrics);
Set<String> allFields = new HashSet<>(); Set<String> allFields = new HashSet<>();
Map<String, Measure> allMeasures = new HashMap<>(); Map<String, Measure> allMeasures = new HashMap<>();
modelSchemaResps.stream().forEach(modelSchemaResp -> { modelSchemaResps.stream().forEach(modelSchemaResp -> {
@@ -296,7 +299,8 @@ public class QueryReqConverter {
metricResp.getMetricDefineByMeasureParams())) { metricResp.getMetricDefineByMeasureParams())) {
String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields, allMeasures, String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields, allMeasures,
dimensionResps, dimensionResps,
sqlGenerateUtils.getExpr(metricResp), metricResp.getMetricDefineType(), visitedMetric, sqlGenerateUtils.getExpr(metricResp), metricResp.getMetricDefineType(), aggOption,
visitedMetric,
deriveMetric, deriveDimension); deriveMetric, deriveDimension);
replaces.put(metricResp.getBizName(), expr); replaces.put(metricResp.getBizName(), expr);
log.info("derived metric {}->{}", metricResp.getBizName(), expr); log.info("derived metric {}->{}", metricResp.getBizName(), expr);