diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectFunctionHelper.java index 5a5654b8a..8836e94c2 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectFunctionHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectFunctionHelper.java @@ -1,11 +1,14 @@ package com.tencent.supersonic.common.util.jsqlparser; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; @@ -23,6 +26,8 @@ import org.springframework.util.CollectionUtils; @Slf4j public class SqlSelectFunctionHelper { + public static List aggregateFunctionName = Arrays.asList("SUM", "COUNT", "MAX", "MIN", "AVG"); + public static boolean hasAggregateFunction(String sql) { if (!CollectionUtils.isEmpty(getFunctions(sql))) { return true; @@ -84,5 +89,22 @@ public class SqlSelectFunctionHelper { return sumFunction; } + public static String getFirstAggregateFunctions(String expr) { + List functions = getAggregateFunctions(expr); + return CollectionUtils.isEmpty(functions) ? "" : functions.get(0); + } + + public static List getAggregateFunctions(String expr) { + Expression expression = QueryExpressionReplaceVisitor.getExpression(expr); + if (Objects.nonNull(expression)) { + FunctionVisitor visitor = new FunctionVisitor(); + expression.accept(visitor); + Set functions = visitor.getFunctionNames(); + return functions.stream() + .filter(t -> aggregateFunctionName.contains(t.toUpperCase())).collect(Collectors.toList()); + } + return new ArrayList<>(); + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java index 31f134a95..99aa6c1e6 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java @@ -18,6 +18,7 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.ChatGptHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.Measure; import com.tencent.supersonic.headless.api.pojo.MeasureParam; @@ -64,14 +65,6 @@ import com.tencent.supersonic.headless.server.service.TagMetaService; import com.tencent.supersonic.headless.server.utils.MetricCheckUtils; import com.tencent.supersonic.headless.server.utils.MetricConverter; import com.tencent.supersonic.headless.server.utils.ModelClusterBuilder; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.context.ApplicationEventPublisher; -import org.springframework.context.annotation.Lazy; -import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -85,6 +78,13 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; @Service @Slf4j @@ -109,14 +109,14 @@ public class MetricServiceImpl implements MetricService { private MetaDiscoveryService metaDiscoveryService; public MetricServiceImpl(MetricRepository metricRepository, - ModelService modelService, - ChatGptHelper chatGptHelper, - CollectService collectService, - DataSetService dataSetService, - ApplicationEventPublisher eventPublisher, - DimensionService dimensionService, - TagMetaService tagMetaService, - @Lazy MetaDiscoveryService metaDiscoveryService) { + ModelService modelService, + ChatGptHelper chatGptHelper, + CollectService collectService, + DataSetService dataSetService, + ApplicationEventPublisher eventPublisher, + DimensionService dimensionService, + TagMetaService tagMetaService, + @Lazy MetaDiscoveryService metaDiscoveryService) { this.metricRepository = metricRepository; this.modelService = modelService; this.chatGptHelper = chatGptHelper; @@ -331,7 +331,7 @@ public class MetricServiceImpl implements MetricService { List idsToFilter = getIdsToFilter(pageMetricReq, collectIds); metricFilter.setIds(idsToFilter); PageInfo metricDOPageInfo = PageHelper.startPage(pageMetricReq.getCurrent(), - pageMetricReq.getPageSize()) + pageMetricReq.getPageSize()) .doSelectPageInfo(() -> queryMetric(metricFilter)); PageInfo pageInfo = new PageInfo<>(); BeanUtils.copyProperties(metricDOPageInfo, pageInfo); @@ -434,7 +434,7 @@ public class MetricServiceImpl implements MetricService { } private boolean filterByField(List metricResps, MetricResp metricResp, - List fields, Set metricRespFiltered) { + List fields, Set metricRespFiltered) { if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) { List ids = metricResp.getMetricDefineByMetricParams().getMetrics() .stream().map(MetricParam::getId).collect(Collectors.toList()); @@ -472,8 +472,8 @@ public class MetricServiceImpl implements MetricService { metricFilter.setModelIds(Lists.newArrayList(modelId)); List metricResps = getMetrics(metricFilter); return metricResps.stream().filter(metricResp -> - MetricDefineType.FIELD.equals(metricResp.getMetricDefineType()) - || MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) + MetricDefineType.FIELD.equals(metricResp.getMetricDefineType()) + || MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) .collect(Collectors.toList()); } @@ -700,9 +700,7 @@ public class MetricServiceImpl implements MetricService { public void batchFillMetricDefaultAgg(List metricResps, List modelResps) { Map modelRespMap = modelResps.stream().collect(Collectors.toMap(ModelResp::getId, m -> m)); for (MetricResp metricResp : metricResps) { - if (MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) { - fillDefaultAgg(metricResp, modelRespMap.get(metricResp.getModelId())); - } + fillDefaultAgg(metricResp, modelRespMap.get(metricResp.getModelId())); } } @@ -715,9 +713,28 @@ public class MetricServiceImpl implements MetricService { } private void fillDefaultAgg(MetricResp metricResp, ModelResp modelResp) { - if (modelResp == null) { + if (modelResp == null || (Objects.nonNull(metricResp.getDefaultAgg()) && !metricResp.getDefaultAgg() + .isEmpty())) { return; } + // FIELD define will get from expr + if (MetricDefineType.FIELD.equals(metricResp.getMetricDefineType())) { + metricResp.setDefaultAgg(SqlSelectFunctionHelper.getFirstAggregateFunctions(metricResp.getExpr())); + return; + } + // METRIC define will get from first metric + if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) { + if (!CollectionUtils.isEmpty( + metricResp.getMetricDefineByMetricParams().getMetrics())) { + MetricParam metricParam = metricResp.getMetricDefineByMetricParams().getMetrics().get(0); + MetricResp firstMetricResp = getMetric(modelResp.getDomainId(), metricParam.getBizName()); + if (Objects.nonNull(firstMetricResp)) { + fillDefaultAgg(firstMetricResp, modelResp); + } + } + return; + } + // Measure define will get from first measure List measures = modelResp.getModelDetail().getMeasures(); MeasureParam firstMeasure = metricResp.getMetricDefineByMeasureParams() .getMeasures().get(0); @@ -817,7 +834,7 @@ public class MetricServiceImpl implements MetricService { } private Set getModelIds(Set modelIdsByDomainId, List metricResps, - List dimensionResps) { + List dimensionResps) { Set result = new HashSet<>(); if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) { result.addAll(modelIdsByDomainId);