diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index a2cf6c91e..a3c23ce1a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -13,6 +13,7 @@ import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.headless.api.pojo.DataSetDetail; +import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; import com.tencent.supersonic.headless.api.pojo.MetaFilter; import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.request.*; @@ -52,6 +53,9 @@ public class DataSetServiceImpl extends ServiceImpl @Autowired private MetricService metricService; + @Autowired + private ModelService modelService; + @Override public DataSetResp save(DataSetReq dataSetReq, User user) { dataSetReq.createdBy(user.getName()); @@ -77,7 +81,13 @@ public class DataSetServiceImpl extends ServiceImpl @Override public DataSetResp getDataSet(Long id) { DataSetDO dataSetDO = getById(id); - return convert(dataSetDO); + DataSetResp dataSetResp = convert(dataSetDO); + + if (dataSetResp.getDataSetDetail() != null) { + expandIncludesAllModels(dataSetResp); + } + + return dataSetResp; } @Override @@ -273,6 +283,59 @@ public class DataSetServiceImpl extends ServiceImpl .map(Object::toString).collect(Collectors.toList()); } + private void expandIncludesAllModels(DataSetResp dataSetResp) { + List configs = dataSetResp.getDataSetDetail().getDataSetModelConfigs(); + if (CollectionUtils.isEmpty(configs)) { + return; + } + + Set includeAllModelIds = configs.stream() + .filter(DataSetModelConfig::getIncludesAll) + .map(DataSetModelConfig::getId) + .collect(Collectors.toSet()); + + if (CollectionUtils.isEmpty(includeAllModelIds)) { + return; + } + + MetaFilter metaFilter = new MetaFilter(); + metaFilter.setModelIds(new ArrayList<>(includeAllModelIds)); + metaFilter.setStatus(StatusEnum.ONLINE.getCode()); + + List allDimensions = dimensionService.getDimensions(metaFilter); + List allMetrics = metricService.getMetrics(metaFilter); + + Map> modelDimensionMap = allDimensions.stream() + .collect(Collectors.groupingBy( + DimensionResp::getModelId, + Collectors.mapping(DimensionResp::getId, Collectors.toList()) + )); + + Map> modelMetricMap = allMetrics.stream() + .collect(Collectors.groupingBy( + MetricResp::getModelId, + Collectors.mapping(MetricResp::getId, Collectors.toList()) + )); + + for (DataSetModelConfig config : configs) { + if (Boolean.TRUE.equals(config.getIncludesAll())) { + Long modelId = config.getId(); + + List modelDimensions = modelDimensionMap.getOrDefault(modelId, Lists.newArrayList()); + Set existingDimensions = new HashSet<>(config.getDimensions()); + existingDimensions.addAll(modelDimensions); + config.setDimensions(new ArrayList<>(existingDimensions)); + + List modelMetrics = modelMetricMap.getOrDefault(modelId, Lists.newArrayList()); + Set existingMetrics = new HashSet<>(config.getMetrics()); + existingMetrics.addAll(modelMetrics); + config.setMetrics(new ArrayList<>(existingMetrics)); + + config.setIncludesAll(false); + } + } + } + public Long getDataSetIdFromSql(String sql, User user) { List dataSets = null; try {