(improvement)(Headless) Obtain metric default agg (#984) (#985)

This commit is contained in:
jipeli
2024-05-13 12:05:02 +08:00
committed by GitHub
parent 210591e28f
commit 947a01e8ba
2 changed files with 64 additions and 25 deletions

View File

@@ -1,11 +1,14 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Function;
@@ -23,6 +26,8 @@ import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class SqlSelectFunctionHelper { public class SqlSelectFunctionHelper {
public static List<String> aggregateFunctionName = Arrays.asList("SUM", "COUNT", "MAX", "MIN", "AVG");
public static boolean hasAggregateFunction(String sql) { public static boolean hasAggregateFunction(String sql) {
if (!CollectionUtils.isEmpty(getFunctions(sql))) { if (!CollectionUtils.isEmpty(getFunctions(sql))) {
return true; return true;
@@ -84,5 +89,22 @@ public class SqlSelectFunctionHelper {
return sumFunction; return sumFunction;
} }
public static String getFirstAggregateFunctions(String expr) {
List<String> functions = getAggregateFunctions(expr);
return CollectionUtils.isEmpty(functions) ? "" : functions.get(0);
}
public static List<String> getAggregateFunctions(String expr) {
Expression expression = QueryExpressionReplaceVisitor.getExpression(expr);
if (Objects.nonNull(expression)) {
FunctionVisitor visitor = new FunctionVisitor();
expression.accept(visitor);
Set<String> functions = visitor.getFunctionNames();
return functions.stream()
.filter(t -> aggregateFunctionName.contains(t.toUpperCase())).collect(Collectors.toList());
}
return new ArrayList<>();
}
} }

View File

@@ -18,6 +18,7 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.ChatGptHelper; import com.tencent.supersonic.common.util.ChatGptHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.Measure; import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MeasureParam; import com.tencent.supersonic.headless.api.pojo.MeasureParam;
@@ -64,14 +65,6 @@ import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.utils.MetricCheckUtils; import com.tencent.supersonic.headless.server.utils.MetricCheckUtils;
import com.tencent.supersonic.headless.server.utils.MetricConverter; import com.tencent.supersonic.headless.server.utils.MetricConverter;
import com.tencent.supersonic.headless.server.utils.ModelClusterBuilder; import com.tencent.supersonic.headless.server.utils.ModelClusterBuilder;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@@ -85,6 +78,13 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service @Service
@Slf4j @Slf4j
@@ -109,14 +109,14 @@ public class MetricServiceImpl implements MetricService {
private MetaDiscoveryService metaDiscoveryService; private MetaDiscoveryService metaDiscoveryService;
public MetricServiceImpl(MetricRepository metricRepository, public MetricServiceImpl(MetricRepository metricRepository,
ModelService modelService, ModelService modelService,
ChatGptHelper chatGptHelper, ChatGptHelper chatGptHelper,
CollectService collectService, CollectService collectService,
DataSetService dataSetService, DataSetService dataSetService,
ApplicationEventPublisher eventPublisher, ApplicationEventPublisher eventPublisher,
DimensionService dimensionService, DimensionService dimensionService,
TagMetaService tagMetaService, TagMetaService tagMetaService,
@Lazy MetaDiscoveryService metaDiscoveryService) { @Lazy MetaDiscoveryService metaDiscoveryService) {
this.metricRepository = metricRepository; this.metricRepository = metricRepository;
this.modelService = modelService; this.modelService = modelService;
this.chatGptHelper = chatGptHelper; this.chatGptHelper = chatGptHelper;
@@ -331,7 +331,7 @@ public class MetricServiceImpl implements MetricService {
List<Long> idsToFilter = getIdsToFilter(pageMetricReq, collectIds); List<Long> idsToFilter = getIdsToFilter(pageMetricReq, collectIds);
metricFilter.setIds(idsToFilter); metricFilter.setIds(idsToFilter);
PageInfo<MetricDO> metricDOPageInfo = PageHelper.startPage(pageMetricReq.getCurrent(), PageInfo<MetricDO> metricDOPageInfo = PageHelper.startPage(pageMetricReq.getCurrent(),
pageMetricReq.getPageSize()) pageMetricReq.getPageSize())
.doSelectPageInfo(() -> queryMetric(metricFilter)); .doSelectPageInfo(() -> queryMetric(metricFilter));
PageInfo<MetricResp> pageInfo = new PageInfo<>(); PageInfo<MetricResp> pageInfo = new PageInfo<>();
BeanUtils.copyProperties(metricDOPageInfo, pageInfo); BeanUtils.copyProperties(metricDOPageInfo, pageInfo);
@@ -434,7 +434,7 @@ public class MetricServiceImpl implements MetricService {
} }
private boolean filterByField(List<MetricResp> metricResps, MetricResp metricResp, private boolean filterByField(List<MetricResp> metricResps, MetricResp metricResp,
List<String> fields, Set<MetricResp> metricRespFiltered) { List<String> fields, Set<MetricResp> metricRespFiltered) {
if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) { if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) {
List<Long> ids = metricResp.getMetricDefineByMetricParams().getMetrics() List<Long> ids = metricResp.getMetricDefineByMetricParams().getMetrics()
.stream().map(MetricParam::getId).collect(Collectors.toList()); .stream().map(MetricParam::getId).collect(Collectors.toList());
@@ -472,8 +472,8 @@ public class MetricServiceImpl implements MetricService {
metricFilter.setModelIds(Lists.newArrayList(modelId)); metricFilter.setModelIds(Lists.newArrayList(modelId));
List<MetricResp> metricResps = getMetrics(metricFilter); List<MetricResp> metricResps = getMetrics(metricFilter);
return metricResps.stream().filter(metricResp -> return metricResps.stream().filter(metricResp ->
MetricDefineType.FIELD.equals(metricResp.getMetricDefineType()) MetricDefineType.FIELD.equals(metricResp.getMetricDefineType())
|| MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) || MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType()))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@@ -700,9 +700,7 @@ public class MetricServiceImpl implements MetricService {
public void batchFillMetricDefaultAgg(List<MetricResp> metricResps, List<ModelResp> modelResps) { public void batchFillMetricDefaultAgg(List<MetricResp> metricResps, List<ModelResp> modelResps) {
Map<Long, ModelResp> modelRespMap = modelResps.stream().collect(Collectors.toMap(ModelResp::getId, m -> m)); Map<Long, ModelResp> modelRespMap = modelResps.stream().collect(Collectors.toMap(ModelResp::getId, m -> m));
for (MetricResp metricResp : metricResps) { for (MetricResp metricResp : metricResps) {
if (MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) { fillDefaultAgg(metricResp, modelRespMap.get(metricResp.getModelId()));
fillDefaultAgg(metricResp, modelRespMap.get(metricResp.getModelId()));
}
} }
} }
@@ -715,9 +713,28 @@ public class MetricServiceImpl implements MetricService {
} }
private void fillDefaultAgg(MetricResp metricResp, ModelResp modelResp) { private void fillDefaultAgg(MetricResp metricResp, ModelResp modelResp) {
if (modelResp == null) { if (modelResp == null || (Objects.nonNull(metricResp.getDefaultAgg()) && !metricResp.getDefaultAgg()
.isEmpty())) {
return; return;
} }
// FIELD define will get from expr
if (MetricDefineType.FIELD.equals(metricResp.getMetricDefineType())) {
metricResp.setDefaultAgg(SqlSelectFunctionHelper.getFirstAggregateFunctions(metricResp.getExpr()));
return;
}
// METRIC define will get from first metric
if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) {
if (!CollectionUtils.isEmpty(
metricResp.getMetricDefineByMetricParams().getMetrics())) {
MetricParam metricParam = metricResp.getMetricDefineByMetricParams().getMetrics().get(0);
MetricResp firstMetricResp = getMetric(modelResp.getDomainId(), metricParam.getBizName());
if (Objects.nonNull(firstMetricResp)) {
fillDefaultAgg(firstMetricResp, modelResp);
}
}
return;
}
// Measure define will get from first measure
List<Measure> measures = modelResp.getModelDetail().getMeasures(); List<Measure> measures = modelResp.getModelDetail().getMeasures();
MeasureParam firstMeasure = metricResp.getMetricDefineByMeasureParams() MeasureParam firstMeasure = metricResp.getMetricDefineByMeasureParams()
.getMeasures().get(0); .getMeasures().get(0);
@@ -817,7 +834,7 @@ public class MetricServiceImpl implements MetricService {
} }
private Set<Long> getModelIds(Set<Long> modelIdsByDomainId, List<MetricResp> metricResps, private Set<Long> getModelIds(Set<Long> modelIdsByDomainId, List<MetricResp> metricResps,
List<DimensionResp> dimensionResps) { List<DimensionResp> dimensionResps) {
Set<Long> result = new HashSet<>(); Set<Long> result = new HashSet<>();
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) { if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) {
result.addAll(modelIdsByDomainId); result.addAll(modelIdsByDomainId);