[improvement][headless]Introduce DerivedMetricConverter and optimize metric creation in S2VisitsDemo.

This commit is contained in:
jerryjzhang
2024-12-11 17:51:02 +08:00
parent 4062a13126
commit f97ac1da83
17 changed files with 205 additions and 291 deletions

View File

@@ -0,0 +1,115 @@
package com.tencent.supersonic.headless.core.translator.converter;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.pojo.SqlQueryParam;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam;
import com.tencent.supersonic.headless.core.utils.SqlGenerateUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*;
/**
 * {@link QueryConverter} that expands derived metrics referenced by a SQL query.
 *
 * <p>For every queried metric that {@link MetricType#isDerived} reports as derived, the metric
 * name in the SQL text is replaced by the expression produced by
 * {@link SqlGenerateUtils#generateDerivedMetric}. The ontology query parameter is then updated
 * so that it carries the metrics and dimensions those expressions depend on.
 *
 * <p>Registered as the Spring bean {@code "DerivedMetricConverter"}.
 */
@Component("DerivedMetricConverter")
@Slf4j
public class DerivedMetricConverter implements QueryConverter {

    /**
     * Accepts a statement only when it carries a SQL query parameter with non-blank SQL text.
     *
     * @param queryStatement the statement under translation
     * @return {@code true} if {@link #convert(QueryStatement)} should run for this statement
     */
    @Override
    public boolean accept(QueryStatement queryStatement) {
        if (Objects.isNull(queryStatement.getSqlQueryParam())) {
            return false;
        }
        return StringUtils.isNotBlank(queryStatement.getSqlQueryParam().getSql());
    }

    /**
     * Rewrites the statement's SQL so derived metrics are replaced by their expressions.
     *
     * <p>When at least one replacement is produced, the ontology aggregation option is switched
     * to {@link AggOption#NATIVE} and the collected measure names are appended to the ontology
     * metrics. The (possibly unchanged) SQL is always written back to both the SQL query
     * parameter and the statement itself.
     *
     * @param queryStatement the statement to rewrite in place
     * @throws Exception declared by the {@link QueryConverter} contract
     */
    @Override
    public void convert(QueryStatement queryStatement) throws Exception {
        SemanticSchemaResp schema = queryStatement.getSemanticSchemaResp();
        SqlQueryParam sqlQueryParam = queryStatement.getSqlQueryParam();
        OntologyQueryParam ontologyQueryParam = queryStatement.getOntologyQueryParam();
        String rewrittenSql = sqlQueryParam.getSql();

        // Filled by generateDerivedMetric with the names the ontology query must expose.
        Set<String> requiredMeasures = new HashSet<>();
        Map<String, String> exprByMetric = generateDerivedMetric(schema,
                ontologyQueryParam.getAggOption(), ontologyQueryParam.getMetrics(),
                ontologyQueryParam.getDimensions(), requiredMeasures);

        boolean hasReplacements = !CollectionUtils.isEmpty(exprByMetric);
        if (hasReplacements) {
            // Substitute each derived metric name in the SQL with its generated expression.
            rewrittenSql = SqlReplaceHelper.replaceSqlByExpression(rewrittenSql, exprByMetric);
            ontologyQueryParam.setAggOption(AggOption.NATIVE);
            // The ontology query must now also return what the expressions are built from.
            if (!CollectionUtils.isEmpty(requiredMeasures)) {
                ontologyQueryParam.getMetrics().addAll(requiredMeasures);
            }
        }

        sqlQueryParam.setSql(rewrittenSql);
        queryStatement.setSql(queryStatement.getSqlQueryParam().getSql());
    }

    /**
     * Builds the metric-name to expression map for the derived metrics among {@code metrics}.
     *
     * <p>Side effects on the arguments: non-derived queried metrics and every name collected by
     * {@link SqlGenerateUtils#generateDerivedMetric} are added to {@code measures}; dimensions
     * collected by that call which are not yet present are added to {@code dimensions}.
     *
     * @param semanticSchemaResp schema supplying metrics, dimensions and models
     * @param aggOption aggregation option forwarded to the expression generator
     * @param metrics bizNames of the queried metrics
     * @param dimensions bizNames of the queried dimensions; may be extended in place
     * @param measures output set receiving the names described above
     * @return map from derived metric bizName to its expression; empty when no queried metric
     *         is derived
     */
    private Map<String, String> generateDerivedMetric(SemanticSchemaResp semanticSchemaResp,
            AggOption aggOption, Set<String> metrics, Set<String> dimensions,
            Set<String> measures) {
        SqlGenerateUtils exprGenerator = ContextUtils.getBean(SqlGenerateUtils.class);
        Map<String, String> exprByMetric = new HashMap<>();
        List<MetricSchemaResp> schemaMetrics = semanticSchemaResp.getMetrics();
        List<DimSchemaResp> schemaDimensions = semanticSchemaResp.getDimensions();

        // Bail out early unless at least one queried metric is derived.
        boolean anyDerived = false;
        for (MetricSchemaResp schemaMetric : schemaMetrics) {
            if (metrics.contains(schemaMetric.getBizName())
                    && MetricType.isDerived(schemaMetric.getMetricDefineType(),
                            schemaMetric.getMetricDefineByMeasureParams())) {
                anyDerived = true;
                break;
            }
        }
        if (!anyDerived) {
            return exprByMetric;
        }

        log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics);

        // Gather every field and measure declared by the models in the schema.
        Set<String> allFields = new HashSet<>();
        Map<String, Measure> measureByBizName = new HashMap<>();
        semanticSchemaResp.getModelResps().forEach(model -> {
            allFields.addAll(model.getFieldList());
            if (model.getModelDetail().getMeasures() == null) {
                return;
            }
            for (Measure measure : model.getModelDetail().getMeasures()) {
                measureByBizName.put(measure.getBizName(), measure);
            }
        });

        // Accumulators shared across all generateDerivedMetric calls below.
        Set<String> collectedDimensions = new HashSet<>();
        Set<String> collectedMetrics = new HashSet<>();
        Map<String, String> visited = new HashMap<>();

        for (MetricResp candidate : schemaMetrics) {
            String bizName = candidate.getBizName();
            if (!metrics.contains(bizName)) {
                continue;
            }
            boolean derived = MetricType.isDerived(candidate.getMetricDefineType(),
                    candidate.getMetricDefineByMeasureParams());
            if (!derived) {
                // Non-derived metrics are passed through by name.
                measures.add(bizName);
                continue;
            }
            String expr = exprGenerator.generateDerivedMetric(schemaMetrics, allFields,
                    measureByBizName, schemaDimensions, exprGenerator.getExpr(candidate),
                    candidate.getMetricDefineType(), aggOption, visited, collectedMetrics,
                    collectedDimensions);
            exprByMetric.put(bizName, expr);
            log.debug("derived metric {}->{}", bizName, expr);
        }

        measures.addAll(collectedMetrics);
        for (String dimension : collectedDimensions) {
            if (!dimensions.contains(dimension)) {
                dimensions.add(dimension);
            }
        }
        return exprByMetric;
    }
}

View File

@@ -7,11 +7,10 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
import com.tencent.supersonic.headless.api.pojo.response.*;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.pojo.SqlQueryParam;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam;
@@ -20,7 +19,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@@ -66,6 +64,7 @@ public class SqlQueryConverter implements QueryConverter {
ontologyQueryParam.getMetrics().addAll(metrics);
ontologyQueryParam.getDimensions().addAll(dimensions);
AggOption sqlQueryAggOption = getAggOption(sqlQueryParam.getSql(), metricSchemas);
// if sql query itself has aggregation, ontology query just returns detail
if (sqlQueryAggOption.equals(AggOption.AGGREGATION)) {
ontologyQueryParam.setAggOption(AggOption.NATIVE);
@@ -74,9 +73,6 @@ public class SqlQueryConverter implements QueryConverter {
}
ontologyQueryParam.setNativeQuery(!AggOption.isAgg(ontologyQueryParam.getAggOption()));
queryStatement.setOntologyQueryParam(ontologyQueryParam);
generateDerivedMetric(sqlGenerateUtils, queryStatement);
queryStatement.setSql(sqlQueryParam.getSql());
log.info("parse sqlQuery [{}] ", sqlQueryParam);
}
@@ -138,92 +134,6 @@ public class SqlQueryConverter implements QueryConverter {
.collect(Collectors.toList());
}
/**
 * Replaces derived metrics in the statement's SQL with their generated expressions and
 * updates the ontology query parameter to match.
 *
 * <p>Reads the semantic schema, SQL query parameter and ontology query parameter from
 * {@code queryStatement}. The SQL held by the SQL query parameter is always written back at
 * the end, whether or not it was rewritten.
 *
 * @param sqlGenerateUtils helper used to build expressions and the internal metric name
 * @param queryStatement statement whose SQL and ontology parameter are modified in place
 */
private void generateDerivedMetric(SqlGenerateUtils sqlGenerateUtils,
QueryStatement queryStatement) {
SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp();
SqlQueryParam sqlParam = queryStatement.getSqlQueryParam();
OntologyQueryParam ontologyParam = queryStatement.getOntologyQueryParam();
String sql = sqlParam.getSql();
// Output set: filled by the overload below with the names to add to the ontology metrics.
Set<String> measures = new HashSet<>();
Map<String, String> replaces = generateDerivedMetric(sqlGenerateUtils, semanticSchemaResp,
ontologyParam.getAggOption(), ontologyParam.getMetrics(),
ontologyParam.getDimensions(), measures);
if (!CollectionUtils.isEmpty(replaces)) {
// Substitute each derived metric name in the SQL with its expression.
sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces);
ontologyParam.setAggOption(AggOption.NATIVE);
// Add the collected measure names to the ontology metrics.
if (!CollectionUtils.isEmpty(measures)) {
ontologyParam.getMetrics().addAll(measures);
} else {
// No measures were collected: fall back to the internal metric name generated
// for the model chosen by getDefaultModel.
ontologyParam.getMetrics().add(sqlGenerateUtils.generateInternalMetricName(
getDefaultModel(semanticSchemaResp, ontologyParam.getDimensions())));
}
}
sqlParam.setSql(sql);
}
/**
 * Builds a map from derived metric bizName to its generated expression for the queried
 * metrics.
 *
 * <p>Side effects on the arguments: queried metrics that are not derived, plus every name the
 * expression generator collects in {@code derivedMetrics}, are added to {@code measures};
 * collected dimensions not already present are added to {@code dimensions}.
 *
 * @param sqlGenerateUtils helper that generates the expression of a derived metric
 * @param semanticSchemaResp schema supplying metrics, dimensions and models
 * @param aggOption aggregation option forwarded to the expression generator
 * @param metrics bizNames of the queried metrics
 * @param dimensions bizNames of the queried dimensions; may be extended in place
 * @param measures output set receiving the names described above
 * @return the replacement map; empty when none of the queried metrics is derived
 */
private Map<String, String> generateDerivedMetric(SqlGenerateUtils sqlGenerateUtils,
SemanticSchemaResp semanticSchemaResp, AggOption aggOption, Set<String> metrics,
Set<String> dimensions, Set<String> measures) {
Map<String, String> result = new HashMap<>();
List<MetricSchemaResp> metricResps = semanticSchemaResp.getMetrics();
List<DimSchemaResp> dimensionResps = semanticSchemaResp.getDimensions();
// Check if any metric is derived
boolean hasDerivedMetrics =
metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType
.isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()));
if (!hasDerivedMetrics) {
// Nothing to expand; the caller's out-parameters are left untouched.
return result;
}
log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics);
// Collect every field name and measure (keyed by bizName) declared by the schema's models.
Set<String> allFields = new HashSet<>();
Map<String, Measure> allMeasures = new HashMap<>();
semanticSchemaResp.getModelResps().forEach(modelResp -> {
allFields.addAll(modelResp.getFieldList());
if (modelResp.getModelDetail().getMeasures() != null) {
modelResp.getModelDetail().getMeasures()
.forEach(measure -> allMeasures.put(measure.getBizName(), measure));
}
});
// Accumulators shared by every sqlGenerateUtils.generateDerivedMetric call in the loop.
Set<String> derivedDimensions = new HashSet<>();
Set<String> derivedMetrics = new HashSet<>();
Map<String, String> visitedMetrics = new HashMap<>();
for (MetricResp metricResp : metricResps) {
if (metrics.contains(metricResp.getBizName())) {
boolean isDerived = MetricType.isDerived(metricResp.getMetricDefineType(),
metricResp.getMetricDefineByMeasureParams());
if (isDerived) {
String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields,
allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp),
metricResp.getMetricDefineType(), aggOption, visitedMetrics,
derivedMetrics, derivedDimensions);
result.put(metricResp.getBizName(), expr);
log.debug("derived metric {}->{}", metricResp.getBizName(), expr);
} else {
// Non-derived queried metrics are passed through by name.
measures.add(metricResp.getBizName());
}
}
}
measures.addAll(derivedMetrics);
// Extend the caller's dimension set with the collected dimensions it does not yet contain.
derivedDimensions.stream().filter(dimension -> !dimensions.contains(dimension))
.forEach(dimensions::add);
return result;
}
private void convertNameToBizName(QueryStatement queryStatement) {
SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp();
Map<String, String> fieldNameToBizNameMap = getFieldNameToBizNameMap(semanticSchemaResp);
@@ -276,18 +186,4 @@ public class SqlQueryConverter implements QueryConverter {
return elements.stream();
}
/**
 * Picks the bizName of the model to use as the default for a query.
 *
 * <p>When {@code dimensions} is non-empty, each model is scored by how many of its detail
 * dimensions have a bizName contained in {@code dimensions}, and the highest-scoring model's
 * bizName is returned ({@code ""} if the schema has no models). Otherwise the first model's
 * bizName is returned.
 *
 * @param semanticSchemaResp schema supplying the candidate models
 * @param dimensions bizNames of the queried dimensions
 * @return the chosen model bizName
 */
private String getDefaultModel(SemanticSchemaResp semanticSchemaResp, Set<String> dimensions) {
if (!CollectionUtils.isEmpty(dimensions)) {
// Model bizName -> number of its dimensions that appear in the query.
Map<String, Long> modelMatchCnt = new HashMap<>();
for (ModelResp modelResp : semanticSchemaResp.getModelResps()) {
modelMatchCnt.put(modelResp.getBizName(), modelResp.getModelDetail().getDimensions()
.stream().filter(d -> dimensions.contains(d.getBizName())).count());
}
// Highest count first; models with equal counts keep the HashMap's iteration order.
return modelMatchCnt.entrySet().stream()
.sorted(Map.Entry.comparingByValue(Comparator.reverseOrder()))
.map(Map.Entry::getKey).findFirst().orElse("");
}
// NOTE(review): get(0) presumably throws if the schema has no models; confirm callers guarantee at least one.
return semanticSchemaResp.getModelResps().get(0).getBizName();
}
}

View File

@@ -140,8 +140,16 @@ public class DataModelNode extends SemanticNode {
Set<String> schemaMetricName =
ontology.getMetrics().stream().map(Metric::getName).collect(Collectors.toSet());
ontology.getMetrics().stream().filter(m -> queryParam.getMetrics().contains(m.getName()))
.forEach(m -> m.getMetricTypeParams().getMeasures()
.forEach(mm -> queryMeasures.add(mm.getName())));
.forEach(m -> {
if (!CollectionUtils.isEmpty(m.getMetricTypeParams().getMeasures())) {
m.getMetricTypeParams().getMeasures()
.forEach(mm -> queryMeasures.add(mm.getName()));
}
if (!CollectionUtils.isEmpty(m.getMetricTypeParams().getFields())) {
m.getMetricTypeParams().getFields()
.forEach(mm -> queryMeasures.add(mm.getName()));
}
});
queryParam.getMetrics().stream().filter(m -> !schemaMetricName.contains(m))
.forEach(queryMeasures::add);
}