[improvement][headless]Introduce DerivedMetricConverter and optimize metric creation in S2VisitsDemo.

This commit is contained in:
jerryjzhang
2024-12-11 17:51:02 +08:00
parent 4062a13126
commit f97ac1da83
17 changed files with 205 additions and 291 deletions

View File

@@ -0,0 +1,115 @@
package com.tencent.supersonic.headless.core.translator.converter;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.pojo.SqlQueryParam;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam;
import com.tencent.supersonic.headless.core.utils.SqlGenerateUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*;
@Component("DerivedMetricConverter")
@Slf4j
public class DerivedMetricConverter implements QueryConverter {
@Override
public boolean accept(QueryStatement queryStatement) {
return Objects.nonNull(queryStatement.getSqlQueryParam())
&& StringUtils.isNotBlank(queryStatement.getSqlQueryParam().getSql());
}
@Override
public void convert(QueryStatement queryStatement) throws Exception {
SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp();
SqlQueryParam sqlParam = queryStatement.getSqlQueryParam();
OntologyQueryParam ontologyParam = queryStatement.getOntologyQueryParam();
String sql = sqlParam.getSql();
Set<String> measures = new HashSet<>();
Map<String, String> replaces =
generateDerivedMetric(semanticSchemaResp, ontologyParam.getAggOption(),
ontologyParam.getMetrics(), ontologyParam.getDimensions(), measures);
if (!CollectionUtils.isEmpty(replaces)) {
// metricTable sql use measures replace metric
sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces);
ontologyParam.setAggOption(AggOption.NATIVE);
// metricTable use measures replace metric
if (!CollectionUtils.isEmpty(measures)) {
ontologyParam.getMetrics().addAll(measures);
}
}
sqlParam.setSql(sql);
queryStatement.setSql(queryStatement.getSqlQueryParam().getSql());
}
private Map<String, String> generateDerivedMetric(SemanticSchemaResp semanticSchemaResp,
AggOption aggOption, Set<String> metrics, Set<String> dimensions,
Set<String> measures) {
SqlGenerateUtils sqlGenerateUtils = ContextUtils.getBean(SqlGenerateUtils.class);
Map<String, String> result = new HashMap<>();
List<MetricSchemaResp> metricResps = semanticSchemaResp.getMetrics();
List<DimSchemaResp> dimensionResps = semanticSchemaResp.getDimensions();
// Check if any metric is derived
boolean hasDerivedMetrics =
metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType
.isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()));
if (!hasDerivedMetrics) {
return result;
}
log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics);
Set<String> allFields = new HashSet<>();
Map<String, Measure> allMeasures = new HashMap<>();
semanticSchemaResp.getModelResps().forEach(modelResp -> {
allFields.addAll(modelResp.getFieldList());
if (modelResp.getModelDetail().getMeasures() != null) {
modelResp.getModelDetail().getMeasures()
.forEach(measure -> allMeasures.put(measure.getBizName(), measure));
}
});
Set<String> derivedDimensions = new HashSet<>();
Set<String> derivedMetrics = new HashSet<>();
Map<String, String> visitedMetrics = new HashMap<>();
for (MetricResp metricResp : metricResps) {
if (metrics.contains(metricResp.getBizName())) {
boolean isDerived = MetricType.isDerived(metricResp.getMetricDefineType(),
metricResp.getMetricDefineByMeasureParams());
if (isDerived) {
String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields,
allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp),
metricResp.getMetricDefineType(), aggOption, visitedMetrics,
derivedMetrics, derivedDimensions);
result.put(metricResp.getBizName(), expr);
log.debug("derived metric {}->{}", metricResp.getBizName(), expr);
} else {
measures.add(metricResp.getBizName());
}
}
}
measures.addAll(derivedMetrics);
derivedDimensions.stream().filter(dimension -> !dimensions.contains(dimension))
.forEach(dimensions::add);
return result;
}
}

View File

@@ -7,11 +7,10 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.EngineType; import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.SchemaItem; import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption; import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.api.pojo.enums.MetricType; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.*; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.pojo.SqlQueryParam; import com.tencent.supersonic.headless.core.pojo.SqlQueryParam;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam; import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam;
@@ -20,7 +19,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -66,6 +64,7 @@ public class SqlQueryConverter implements QueryConverter {
ontologyQueryParam.getMetrics().addAll(metrics); ontologyQueryParam.getMetrics().addAll(metrics);
ontologyQueryParam.getDimensions().addAll(dimensions); ontologyQueryParam.getDimensions().addAll(dimensions);
AggOption sqlQueryAggOption = getAggOption(sqlQueryParam.getSql(), metricSchemas); AggOption sqlQueryAggOption = getAggOption(sqlQueryParam.getSql(), metricSchemas);
// if sql query itself has aggregation, ontology query just returns detail // if sql query itself has aggregation, ontology query just returns detail
if (sqlQueryAggOption.equals(AggOption.AGGREGATION)) { if (sqlQueryAggOption.equals(AggOption.AGGREGATION)) {
ontologyQueryParam.setAggOption(AggOption.NATIVE); ontologyQueryParam.setAggOption(AggOption.NATIVE);
@@ -74,9 +73,6 @@ public class SqlQueryConverter implements QueryConverter {
} }
ontologyQueryParam.setNativeQuery(!AggOption.isAgg(ontologyQueryParam.getAggOption())); ontologyQueryParam.setNativeQuery(!AggOption.isAgg(ontologyQueryParam.getAggOption()));
queryStatement.setOntologyQueryParam(ontologyQueryParam); queryStatement.setOntologyQueryParam(ontologyQueryParam);
generateDerivedMetric(sqlGenerateUtils, queryStatement);
queryStatement.setSql(sqlQueryParam.getSql()); queryStatement.setSql(sqlQueryParam.getSql());
log.info("parse sqlQuery [{}] ", sqlQueryParam); log.info("parse sqlQuery [{}] ", sqlQueryParam);
} }
@@ -138,92 +134,6 @@ public class SqlQueryConverter implements QueryConverter {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private void generateDerivedMetric(SqlGenerateUtils sqlGenerateUtils,
QueryStatement queryStatement) {
SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp();
SqlQueryParam sqlParam = queryStatement.getSqlQueryParam();
OntologyQueryParam ontologyParam = queryStatement.getOntologyQueryParam();
String sql = sqlParam.getSql();
Set<String> measures = new HashSet<>();
Map<String, String> replaces = generateDerivedMetric(sqlGenerateUtils, semanticSchemaResp,
ontologyParam.getAggOption(), ontologyParam.getMetrics(),
ontologyParam.getDimensions(), measures);
if (!CollectionUtils.isEmpty(replaces)) {
// metricTable sql use measures replace metric
sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces);
ontologyParam.setAggOption(AggOption.NATIVE);
// metricTable use measures replace metric
if (!CollectionUtils.isEmpty(measures)) {
ontologyParam.getMetrics().addAll(measures);
} else {
// empty measure , fill default
ontologyParam.getMetrics().add(sqlGenerateUtils.generateInternalMetricName(
getDefaultModel(semanticSchemaResp, ontologyParam.getDimensions())));
}
}
sqlParam.setSql(sql);
}
private Map<String, String> generateDerivedMetric(SqlGenerateUtils sqlGenerateUtils,
SemanticSchemaResp semanticSchemaResp, AggOption aggOption, Set<String> metrics,
Set<String> dimensions, Set<String> measures) {
Map<String, String> result = new HashMap<>();
List<MetricSchemaResp> metricResps = semanticSchemaResp.getMetrics();
List<DimSchemaResp> dimensionResps = semanticSchemaResp.getDimensions();
// Check if any metric is derived
boolean hasDerivedMetrics =
metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType
.isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()));
if (!hasDerivedMetrics) {
return result;
}
log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics);
Set<String> allFields = new HashSet<>();
Map<String, Measure> allMeasures = new HashMap<>();
semanticSchemaResp.getModelResps().forEach(modelResp -> {
allFields.addAll(modelResp.getFieldList());
if (modelResp.getModelDetail().getMeasures() != null) {
modelResp.getModelDetail().getMeasures()
.forEach(measure -> allMeasures.put(measure.getBizName(), measure));
}
});
Set<String> derivedDimensions = new HashSet<>();
Set<String> derivedMetrics = new HashSet<>();
Map<String, String> visitedMetrics = new HashMap<>();
for (MetricResp metricResp : metricResps) {
if (metrics.contains(metricResp.getBizName())) {
boolean isDerived = MetricType.isDerived(metricResp.getMetricDefineType(),
metricResp.getMetricDefineByMeasureParams());
if (isDerived) {
String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields,
allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp),
metricResp.getMetricDefineType(), aggOption, visitedMetrics,
derivedMetrics, derivedDimensions);
result.put(metricResp.getBizName(), expr);
log.debug("derived metric {}->{}", metricResp.getBizName(), expr);
} else {
measures.add(metricResp.getBizName());
}
}
}
measures.addAll(derivedMetrics);
derivedDimensions.stream().filter(dimension -> !dimensions.contains(dimension))
.forEach(dimensions::add);
return result;
}
private void convertNameToBizName(QueryStatement queryStatement) { private void convertNameToBizName(QueryStatement queryStatement) {
SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp(); SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp();
Map<String, String> fieldNameToBizNameMap = getFieldNameToBizNameMap(semanticSchemaResp); Map<String, String> fieldNameToBizNameMap = getFieldNameToBizNameMap(semanticSchemaResp);
@@ -276,18 +186,4 @@ public class SqlQueryConverter implements QueryConverter {
return elements.stream(); return elements.stream();
} }
private String getDefaultModel(SemanticSchemaResp semanticSchemaResp, Set<String> dimensions) {
if (!CollectionUtils.isEmpty(dimensions)) {
Map<String, Long> modelMatchCnt = new HashMap<>();
for (ModelResp modelResp : semanticSchemaResp.getModelResps()) {
modelMatchCnt.put(modelResp.getBizName(), modelResp.getModelDetail().getDimensions()
.stream().filter(d -> dimensions.contains(d.getBizName())).count());
}
return modelMatchCnt.entrySet().stream()
.sorted(Map.Entry.comparingByValue(Comparator.reverseOrder()))
.map(Map.Entry::getKey).findFirst().orElse("");
}
return semanticSchemaResp.getModelResps().get(0).getBizName();
}
} }

View File

@@ -140,8 +140,16 @@ public class DataModelNode extends SemanticNode {
Set<String> schemaMetricName = Set<String> schemaMetricName =
ontology.getMetrics().stream().map(Metric::getName).collect(Collectors.toSet()); ontology.getMetrics().stream().map(Metric::getName).collect(Collectors.toSet());
ontology.getMetrics().stream().filter(m -> queryParam.getMetrics().contains(m.getName())) ontology.getMetrics().stream().filter(m -> queryParam.getMetrics().contains(m.getName()))
.forEach(m -> m.getMetricTypeParams().getMeasures() .forEach(m -> {
.forEach(mm -> queryMeasures.add(mm.getName()))); if (!CollectionUtils.isEmpty(m.getMetricTypeParams().getMeasures())) {
m.getMetricTypeParams().getMeasures()
.forEach(mm -> queryMeasures.add(mm.getName()));
}
if (!CollectionUtils.isEmpty(m.getMetricTypeParams().getFields())) {
m.getMetricTypeParams().getFields()
.forEach(mm -> queryMeasures.add(mm.getName()));
}
});
queryParam.getMetrics().stream().filter(m -> !schemaMetricName.contains(m)) queryParam.getMetrics().stream().filter(m -> !schemaMetricName.contains(m))
.forEach(queryMeasures::add); .forEach(queryMeasures::add);
} }

View File

@@ -62,22 +62,19 @@ public class DimensionServiceImpl extends ServiceImpl<DimensionDOMapper, Dimensi
private DataSetService dataSetService; private DataSetService dataSetService;
private TagMetaService tagMetaService;
@Autowired @Autowired
private ApplicationEventPublisher eventPublisher; private ApplicationEventPublisher eventPublisher;
public DimensionServiceImpl(DimensionRepository dimensionRepository, ModelService modelService, public DimensionServiceImpl(DimensionRepository dimensionRepository, ModelService modelService,
AliasGenerateHelper aliasGenerateHelper, DatabaseService databaseService, AliasGenerateHelper aliasGenerateHelper, DatabaseService databaseService,
ModelRelaService modelRelaService, DataSetService dataSetService, ModelRelaService modelRelaService, DataSetService dataSetService) {
TagMetaService tagMetaService) {
this.modelService = modelService; this.modelService = modelService;
this.dimensionRepository = dimensionRepository; this.dimensionRepository = dimensionRepository;
this.aliasGenerateHelper = aliasGenerateHelper; this.aliasGenerateHelper = aliasGenerateHelper;
this.databaseService = databaseService; this.databaseService = databaseService;
this.modelRelaService = modelRelaService; this.modelRelaService = modelRelaService;
this.dataSetService = dataSetService; this.dataSetService = dataSetService;
this.tagMetaService = tagMetaService;
} }
@Override @Override

View File

@@ -59,15 +59,12 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
private ApplicationEventPublisher eventPublisher; private ApplicationEventPublisher eventPublisher;
private TagMetaService tagMetaService;
private ChatLayerService chatLayerService; private ChatLayerService chatLayerService;
public MetricServiceImpl(MetricRepository metricRepository, ModelService modelService, public MetricServiceImpl(MetricRepository metricRepository, ModelService modelService,
AliasGenerateHelper aliasGenerateHelper, CollectService collectService, AliasGenerateHelper aliasGenerateHelper, CollectService collectService,
DataSetService dataSetService, ApplicationEventPublisher eventPublisher, DataSetService dataSetService, ApplicationEventPublisher eventPublisher,
DimensionService dimensionService, TagMetaService tagMetaService, DimensionService dimensionService, @Lazy ChatLayerService chatLayerService) {
@Lazy ChatLayerService chatLayerService) {
this.metricRepository = metricRepository; this.metricRepository = metricRepository;
this.modelService = modelService; this.modelService = modelService;
this.aliasGenerateHelper = aliasGenerateHelper; this.aliasGenerateHelper = aliasGenerateHelper;
@@ -75,7 +72,6 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
this.collectService = collectService; this.collectService = collectService;
this.dataSetService = dataSetService; this.dataSetService = dataSetService;
this.dimensionService = dimensionService; this.dimensionService = dimensionService;
this.tagMetaService = tagMetaService;
this.chatLayerService = chatLayerService; this.chatLayerService = chatLayerService;
} }

View File

@@ -46,9 +46,9 @@ public class MetricCheckUtils {
throw new InvalidArgumentException("指标定义参数不可为空"); throw new InvalidArgumentException("指标定义参数不可为空");
} }
expr = typeParams.getExpr(); expr = typeParams.getExpr();
if (CollectionUtils.isEmpty(typeParams.getFields())) { // if (CollectionUtils.isEmpty(typeParams.getFields())) {
throw new InvalidArgumentException("定义指标的字段列表参数不可为空"); // throw new InvalidArgumentException("定义指标的字段列表参数不可为空");
} // }
if (!hasAggregateFunction(expr)) { if (!hasAggregateFunction(expr)) {
throw new InvalidArgumentException("基于字段来创建指标,表达式中必须包含聚合函数"); throw new InvalidArgumentException("基于字段来创建指标,表达式中必须包含聚合函数");
} }

View File

@@ -70,11 +70,9 @@ public class MetricServiceImplTest {
ApplicationEventPublisher eventPublisher = Mockito.mock(ApplicationEventPublisher.class); ApplicationEventPublisher eventPublisher = Mockito.mock(ApplicationEventPublisher.class);
DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class); DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class);
DimensionService dimensionService = Mockito.mock(DimensionService.class); DimensionService dimensionService = Mockito.mock(DimensionService.class);
TagMetaService tagMetaService = Mockito.mock(TagMetaService.class);
ChatLayerService chatLayerService = Mockito.mock(ChatLayerService.class); ChatLayerService chatLayerService = Mockito.mock(ChatLayerService.class);
return new MetricServiceImpl(metricRepository, modelService, aliasGenerateHelper, return new MetricServiceImpl(metricRepository, modelService, aliasGenerateHelper,
collectService, dataSetService, eventPublisher, dimensionService, tagMetaService, collectService, dataSetService, eventPublisher, dimensionService, chatLayerService);
chatLayerService);
} }
private MetricReq buildMetricReq() { private MetricReq buildMetricReq() {

View File

@@ -30,6 +30,7 @@ com.tencent.supersonic.headless.core.translator.converter.QueryConverter=\
com.tencent.supersonic.headless.core.translator.converter.SqlVariableConverter,\ com.tencent.supersonic.headless.core.translator.converter.SqlVariableConverter,\
com.tencent.supersonic.headless.core.translator.converter.SqlQueryConverter,\ com.tencent.supersonic.headless.core.translator.converter.SqlQueryConverter,\
com.tencent.supersonic.headless.core.translator.converter.StructQueryConverter,\ com.tencent.supersonic.headless.core.translator.converter.StructQueryConverter,\
com.tencent.supersonic.headless.core.translator.converter.DerivedMetricConverter,\
com.tencent.supersonic.headless.core.translator.converter.MetricRatioConverter com.tencent.supersonic.headless.core.translator.converter.MetricRatioConverter
com.tencent.supersonic.headless.core.translator.optimizer.QueryOptimizer=\ com.tencent.supersonic.headless.core.translator.optimizer.QueryOptimizer=\

View File

@@ -16,54 +16,21 @@ import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuer
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.*;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.Dimension;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Field;
import com.tencent.supersonic.headless.api.pojo.FieldParam;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByFieldParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMetricParams;
import com.tencent.supersonic.headless.api.pojo.MetricParam;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType; import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq; import com.tencent.supersonic.headless.api.pojo.request.*;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; import com.tencent.supersonic.headless.api.pojo.response.*;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.TermReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component @Component
@Slf4j @Slf4j
@@ -87,14 +54,13 @@ public class S2VisitsDemo extends S2BaseDemo {
// create metrics and dimensions // create metrics and dimensions
DimensionResp departmentDimension = getDimension("department", userModel); DimensionResp departmentDimension = getDimension("department", userModel);
MetricResp metricUv = addMetric_uv(pvUvModel, departmentDimension); MetricResp metricUv = addMetric_uv(pvUvModel, departmentDimension);
MetricResp metricPv = getMetric("pv", pvUvModel);
addMetric_pv_avg(metricPv, metricUv, departmentDimension, pvUvModel);
DimensionResp pageDimension = getDimension("page", stayTimeModel); DimensionResp pageDimension = getDimension("page", stayTimeModel);
updateDimension(stayTimeModel, pageDimension); updateDimension(stayTimeModel, pageDimension);
DimensionResp userDimension = getDimension("user_name", userModel); DimensionResp userDimension = getDimension("user_name", userModel);
updateMetric(stayTimeModel, departmentDimension, userDimension); MetricResp metricPv = addMetric_pv(pvUvModel, departmentDimension, userDimension);
updateMetric_pv(pvUvModel, departmentDimension, userDimension, metricPv);
addMetric_pv_avg(metricPv, metricUv, departmentDimension, pvUvModel);
// create dict conf for dimensions // create dict conf for dimensions
enableDimensionValue(departmentDimension); enableDimensionValue(departmentDimension);
@@ -103,7 +69,7 @@ public class S2VisitsDemo extends S2BaseDemo {
// create data set // create data set
DataSetResp s2DataSet = addDataSet(s2Domain); DataSetResp s2DataSet = addDataSet(s2Domain);
addAuthGroup_1(stayTimeModel); addAuthGroup_1(stayTimeModel);
addAuthGroup_2(pvUvModel); addAuthGroup_2(stayTimeModel);
// create terms and plugin // create terms and plugin
addTerm(s2Domain); addTerm(s2Domain);
@@ -196,12 +162,11 @@ public class S2VisitsDemo extends S2BaseDemo {
modelReq.setAdminOrgs(Collections.emptyList()); modelReq.setAdminOrgs(Collections.emptyList());
ModelDetail modelDetail = new ModelDetail(); ModelDetail modelDetail = new ModelDetail();
List<Identify> identifiers = new ArrayList<>(); List<Identify> identifiers = new ArrayList<>();
identifiers.add(new Identify("用户", IdentifyType.primary.name(), "user_name", 1)); identifiers.add(new Identify("用户", IdentifyType.primary.name(), "user_name", 1));
modelDetail.setIdentifiers(identifiers); modelDetail.setIdentifiers(identifiers);
List<Dimension> dimensions = new ArrayList<>(); List<Dimension> dimensions = new ArrayList<>();
dimensions.add(new Dimension("部门", "department", DimensionType.categorical, 1)); dimensions.add(new Dimension("部门", "department", DimensionType.categorical, 1));
// dimensions.add(new Dimension("用户", "user_name", DimensionType.categorical, 1));
modelDetail.setDimensions(dimensions); modelDetail.setDimensions(dimensions);
List<Field> fields = Lists.newArrayList(); List<Field> fields = Lists.newArrayList();
fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build()); fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build());
@@ -209,7 +174,7 @@ public class S2VisitsDemo extends S2BaseDemo {
modelDetail.setFields(fields); modelDetail.setFields(fields);
modelDetail.setMeasures(Collections.emptyList()); modelDetail.setMeasures(Collections.emptyList());
modelDetail.setQueryType("sql_query"); modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("select user_name,department from s2_user_department"); modelDetail.setSqlQuery("select * from s2_user_department");
modelReq.setModelDetail(modelDetail); modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, defaultUser); return modelService.createModel(modelReq, defaultUser);
} }
@@ -238,21 +203,12 @@ public class S2VisitsDemo extends S2BaseDemo {
dimension2.setExpr("page"); dimension2.setExpr("page");
dimensions.add(dimension2); dimensions.add(dimension2);
modelDetail.setDimensions(dimensions); modelDetail.setDimensions(dimensions);
List<Measure> measures = new ArrayList<>();
Measure measure1 = new Measure("访问次数", "pv", AggOperatorEnum.SUM.name(), 1);
measures.add(measure1);
Measure measure2 = new Measure("访问用户数", "user_id", AggOperatorEnum.SUM.name(), 0);
measures.add(measure2);
modelDetail.setMeasures(measures);
List<Field> fields = Lists.newArrayList(); List<Field> fields = Lists.newArrayList();
fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build()); fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build());
fields.add(Field.builder().fieldName("imp_date").dataType("Date").build()); fields.add(Field.builder().fieldName("imp_date").dataType("Date").build());
fields.add(Field.builder().fieldName("page").dataType("Varchar").build()); fields.add(Field.builder().fieldName("page").dataType("Varchar").build());
fields.add(Field.builder().fieldName("pv").dataType("Long").build());
fields.add(Field.builder().fieldName("user_id").dataType("Varchar").build());
modelDetail.setFields(fields); modelDetail.setFields(fields);
modelDetail.setSqlQuery("SELECT imp_date, user_name, page, 1 as pv, " modelDetail.setSqlQuery("SELECT * FROM s2_pv_uv_statis");
+ "user_name as user_id FROM s2_pv_uv_statis");
modelDetail.setQueryType("sql_query"); modelDetail.setQueryType("sql_query");
modelReq.setModelDetail(modelDetail); modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, defaultUser); return modelService.createModel(modelReq, defaultUser);
@@ -265,13 +221,13 @@ public class S2VisitsDemo extends S2BaseDemo {
modelReq.setDescription("停留时长统计"); modelReq.setDescription("停留时长统计");
modelReq.setDomainId(s2Domain.getId()); modelReq.setDomainId(s2Domain.getId());
modelReq.setDatabaseId(s2Database.getId()); modelReq.setDatabaseId(s2Database.getId());
modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); modelReq.setViewers(Arrays.asList("admin", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1")); modelReq.setViewOrgs(Collections.singletonList("1"));
modelReq.setAdmins(Collections.singletonList("admin")); modelReq.setAdmins(Collections.singletonList("admin"));
modelReq.setAdminOrgs(Collections.emptyList()); modelReq.setAdminOrgs(Collections.emptyList());
List<Identify> identifiers = new ArrayList<>(); List<Identify> identifiers = new ArrayList<>();
ModelDetail modelDetail = new ModelDetail(); ModelDetail modelDetail = new ModelDetail();
identifiers.add(new Identify("用户", IdentifyType.foreign.name(), "user_name", 0)); identifiers.add(new Identify("用户", IdentifyType.foreign.name(), "user_name", 0));
modelDetail.setIdentifiers(identifiers); modelDetail.setIdentifiers(identifiers);
List<Dimension> dimensions = new ArrayList<>(); List<Dimension> dimensions = new ArrayList<>();
@@ -293,8 +249,7 @@ public class S2VisitsDemo extends S2BaseDemo {
fields.add(Field.builder().fieldName("page").dataType("Varchar").build()); fields.add(Field.builder().fieldName("page").dataType("Varchar").build());
fields.add(Field.builder().fieldName("stay_hours").dataType("Double").build()); fields.add(Field.builder().fieldName("stay_hours").dataType("Double").build());
modelDetail.setFields(fields); modelDetail.setFields(fields);
modelDetail modelDetail.setSqlQuery("select * from s2_stay_time_statis");
.setSqlQuery("select imp_date,user_name,stay_hours,page from s2_stay_time_statis");
modelDetail.setQueryType("sql_query"); modelDetail.setQueryType("sql_query");
modelReq.setModelDetail(modelDetail); modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, defaultUser); return modelService.createModel(modelReq, defaultUser);
@@ -329,51 +284,23 @@ public class S2VisitsDemo extends S2BaseDemo {
dimensionService.updateDimension(dimensionReq, defaultUser); dimensionService.updateDimension(dimensionReq, defaultUser);
} }
private void updateMetric(ModelResp stayTimeModel, DimensionResp departmentDimension, private MetricResp addMetric_pv(ModelResp pvUvModel, DimensionResp departmentDimension,
DimensionResp userDimension) throws Exception { DimensionResp userDimension) throws Exception {
MetricResp stayHoursMetric = metricService.getMetric(stayTimeModel.getId(), "stay_hours");
MetricReq metricReq = new MetricReq();
metricReq.setModelId(stayTimeModel.getId());
metricReq.setId(stayHoursMetric.getId());
metricReq.setName("停留时长");
metricReq.setBizName("stay_hours");
metricReq.setSensitiveLevel(SensitiveLevelEnum.HIGH.getCode());
metricReq.setDescription("停留时长");
metricReq.setClassifications(Collections.singletonList("核心指标"));
MetricDefineByMeasureParams metricTypeParams = new MetricDefineByMeasureParams();
metricTypeParams.setExpr("s2_stay_time_statis_stay_hours");
List<MeasureParam> measures = new ArrayList<>();
MeasureParam measure = new MeasureParam("s2_stay_time_statis_stay_hours", "",
AggOperatorEnum.SUM.getOperator());
measures.add(measure);
metricTypeParams.setMeasures(measures);
metricReq.setMetricDefineByMeasureParams(metricTypeParams);
metricReq.setMetricDefineType(MetricDefineType.MEASURE);
metricReq.setRelateDimension(getRelateDimension(
Lists.newArrayList(departmentDimension.getId(), userDimension.getId())));
metricService.updateMetric(metricReq, defaultUser);
}
private void updateMetric_pv(ModelResp pvUvModel, DimensionResp departmentDimension,
DimensionResp userDimension, MetricResp metricPv) throws Exception {
MetricReq metricReq = new MetricReq(); MetricReq metricReq = new MetricReq();
metricReq.setModelId(pvUvModel.getId()); metricReq.setModelId(pvUvModel.getId());
metricReq.setId(metricPv.getId());
metricReq.setName("访问次数"); metricReq.setName("访问次数");
metricReq.setBizName("pv"); metricReq.setBizName("pv");
metricReq.setDescription("一段时间内用户的访问次数"); metricReq.setDescription("一段时间内用户的访问次数");
MetricDefineByMeasureParams metricTypeParams = new MetricDefineByMeasureParams(); MetricDefineByFieldParams metricTypeParams = new MetricDefineByFieldParams();
metricTypeParams.setExpr("s2_pv_uv_statis_pv"); metricTypeParams.setExpr("count(imp_date)");
List<MeasureParam> measures = new ArrayList<>(); List<FieldParam> fieldParams = new ArrayList<>();
MeasureParam measure = fieldParams.add(new FieldParam("imp_date"));
new MeasureParam("s2_pv_uv_statis_pv", "", AggOperatorEnum.SUM.getOperator()); metricTypeParams.setFields(fieldParams);
measures.add(measure); metricReq.setMetricDefineByFieldParams(metricTypeParams);
metricTypeParams.setMeasures(measures); metricReq.setMetricDefineType(MetricDefineType.FIELD);
metricReq.setMetricDefineByMeasureParams(metricTypeParams);
metricReq.setMetricDefineType(MetricDefineType.MEASURE);
metricReq.setRelateDimension(getRelateDimension( metricReq.setRelateDimension(getRelateDimension(
Lists.newArrayList(departmentDimension.getId(), userDimension.getId()))); Lists.newArrayList(departmentDimension.getId(), userDimension.getId())));
metricService.updateMetric(metricReq, defaultUser); return metricService.createMetric(metricReq, defaultUser);
} }
private MetricResp addMetric_uv(ModelResp uvModel, DimensionResp departmentDimension) private MetricResp addMetric_uv(ModelResp uvModel, DimensionResp departmentDimension)
@@ -470,15 +397,14 @@ public class S2VisitsDemo extends S2BaseDemo {
authService.addOrUpdateAuthGroup(authGroupReq); authService.addOrUpdateAuthGroup(authGroupReq);
} }
private void addAuthGroup_2(ModelResp pvuvModel) { private void addAuthGroup_2(ModelResp model) {
AuthGroup authGroupReq = new AuthGroup(); AuthGroup authGroupReq = new AuthGroup();
authGroupReq.setModelId(pvuvModel.getId()); authGroupReq.setModelId(model.getId());
authGroupReq.setName("tom_row_permission"); authGroupReq.setName("tom_row_permission");
List<AuthRule> authRules = new ArrayList<>(); List<AuthRule> authRules = new ArrayList<>();
authGroupReq.setAuthRules(authRules); authGroupReq.setAuthRules(authRules);
authGroupReq.setDimensionFilters(Collections.singletonList("user_name = 'tom'")); authGroupReq.setDimensionFilters(Collections.singletonList("user_name = 'tom'"));
authGroupReq.setDimensionFilterDescription("用户名='tom'");
authGroupReq.setAuthorizedUsers(Collections.singletonList("tom")); authGroupReq.setAuthorizedUsers(Collections.singletonList("tom"));
authGroupReq.setAuthorizedDepartmentIds(Collections.emptyList()); authGroupReq.setAuthorizedDepartmentIds(Collections.emptyList());
authService.addOrUpdateAuthGroup(authGroupReq); authService.addOrUpdateAuthGroup(authGroupReq);

View File

@@ -30,6 +30,7 @@ com.tencent.supersonic.headless.core.translator.converter.QueryConverter=\
com.tencent.supersonic.headless.core.translator.converter.SqlVariableConverter,\ com.tencent.supersonic.headless.core.translator.converter.SqlVariableConverter,\
com.tencent.supersonic.headless.core.translator.converter.SqlQueryConverter,\ com.tencent.supersonic.headless.core.translator.converter.SqlQueryConverter,\
com.tencent.supersonic.headless.core.translator.converter.StructQueryConverter,\ com.tencent.supersonic.headless.core.translator.converter.StructQueryConverter,\
com.tencent.supersonic.headless.core.translator.converter.DerivedMetricConverter,\
com.tencent.supersonic.headless.core.translator.converter.MetricRatioConverter com.tencent.supersonic.headless.core.translator.converter.MetricRatioConverter
com.tencent.supersonic.headless.core.translator.optimizer.QueryOptimizer=\ com.tencent.supersonic.headless.core.translator.optimizer.QueryOptimizer=\

View File

@@ -58,6 +58,27 @@ public class MetricTest extends BaseTest {
assert actualResult.getQueryResults().size() == 1; assert actualResult.getQueryResults().size() == 1;
} }
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testDerivedMetricModel() throws Exception {
QueryResult actualResult = submitNewChat("超音数 人均访问次数", agent.getId());
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 1;
}
@Test @Test
public void testMetricFilter() throws Exception { public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", agent.getId()); QueryResult actualResult = submitNewChat("alice的访问次数", agent.getId());
@@ -71,9 +92,9 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户"); SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId())); FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo( expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
@@ -118,14 +139,14 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
list.add("alice"); list.add("alice");
list.add("lucy"); list.add("lucy");
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户"); SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list,
"用户", userElement.getId()); "用户", userElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.setDateInfo( expectedParseInfo.setDateInfo(
@@ -149,7 +170,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(MAX); expectedParseInfo.setAggType(MAX);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
expectedParseInfo.setDateInfo( expectedParseInfo.setDateInfo(
DataUtils.getDateConf(3, DateConf.DateMode.BETWEEN, DatePeriodEnum.DAY)); DataUtils.getDateConf(3, DateConf.DateMode.BETWEEN, DatePeriodEnum.DAY));
@@ -195,10 +216,10 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户"); SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId())); FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo( expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay)); DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));

View File

@@ -78,7 +78,7 @@ public class BaseTest extends BaseApplication {
queryStructReq.setQueryType(queryType); queryStructReq.setQueryType(queryType);
Aggregator aggregator = new Aggregator(); Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.SUM); aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setColumn("pv"); aggregator.setColumn("stay_hours");
queryStructReq.setAggregators(Arrays.asList(aggregator)); queryStructReq.setAggregators(Arrays.asList(aggregator));
if (CollectionUtils.isNotEmpty(groups)) { if (CollectionUtils.isNotEmpty(groups)) {
@@ -93,7 +93,7 @@ public class BaseTest extends BaseApplication {
List<Order> orders = new ArrayList<>(); List<Order> orders = new ArrayList<>();
Order order = new Order(); Order order = new Order();
order.setColumn("pv"); order.setColumn("stay_hours");
orders.add(order); orders.add(order);
queryStructReq.setOrders(orders); queryStructReq.setOrders(orders);
return queryStructReq; return queryStructReq;

View File

@@ -1,33 +0,0 @@
package com.tencent.supersonic.headless;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
import com.tencent.supersonic.headless.server.service.ModelService;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
public class ModelSchemaTest extends BaseTest {
@Autowired
private ModelService modelService;
@Test
void testGetUnAvailableItem() {
FieldRemovedReq fieldRemovedReq = new FieldRemovedReq();
fieldRemovedReq.setModelId(2L);
fieldRemovedReq.setFields(Lists.newArrayList("pv"));
UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq);
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L);
List<Long> actualUnAvailableMetricId =
unAvailableItemResp.getMetricResps().stream().map(MetricResp::getId)
.sorted(Comparator.naturalOrder()).collect(Collectors.toList());
Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId);
}
}

View File

@@ -40,7 +40,7 @@ public class QueryByMetricTest extends BaseTest {
public void testWithMetricAndDimensionNames() throws Exception { public void testWithMetricAndDimensionNames() throws Exception {
QueryMetricReq queryMetricReq = new QueryMetricReq(); QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数")); queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门")); queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门"));
queryMetricReq.getFilters() queryMetricReq.getFilters()
.add(Filter.builder().name("数据日期").operator(FilterOperatorEnum.MINOR_THAN_EQUALS) .add(Filter.builder().name("数据日期").operator(FilterOperatorEnum.MINOR_THAN_EQUALS)
.relation(Filter.Relation.FILTER).value(LocalDate.now().toString()) .relation(Filter.Relation.FILTER).value(LocalDate.now().toString())

View File

@@ -18,11 +18,11 @@ public class QueryBySqlTest extends BaseTest {
@Test @Test
public void testDetailQuery() throws Exception { public void testDetailQuery() throws Exception {
SemanticQueryResp semanticQueryResp = SemanticQueryResp semanticQueryResp =
queryBySql("SELECT 用户,访问次数 FROM 超音数PVUV统计 WHERE 用户='alice' "); queryBySql("SELECT 用户,访问次数 FROM 超音数PVUV统计 WHERE 用户='alice' ");
assertEquals(2, semanticQueryResp.getColumns().size()); assertEquals(2, semanticQueryResp.getColumns().size());
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0); QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
assertEquals("用户", firstColumn.getName()); assertEquals("用户", firstColumn.getName());
QueryColumn secondColumn = semanticQueryResp.getColumns().get(1); QueryColumn secondColumn = semanticQueryResp.getColumns().get(1);
assertEquals("访问次数", secondColumn.getName()); assertEquals("访问次数", secondColumn.getName());
assertTrue(semanticQueryResp.getResultList().size() > 0); assertTrue(semanticQueryResp.getResultList().size() > 0);
@@ -106,10 +106,9 @@ public class QueryBySqlTest extends BaseTest {
@Test @Test
public void testAuthorization_sensitive_metric() throws Exception { public void testAuthorization_sensitive_metric() throws Exception {
User tom = DataUtils.getUserTom(); User tom = DataUtils.getUserAlice();
assertThrows(InvalidPermissionException.class, assertThrows(InvalidPermissionException.class,
() -> queryBySql("SELECT SUM(stay_hours) FROM 停留时长统计 WHERE department ='HR'", () -> queryBySql("SELECT pv_avg FROM 停留时长统计 WHERE department ='HR'", tom));
tom));
} }
@Test @Test
@@ -120,13 +119,4 @@ public class QueryBySqlTest extends BaseTest {
Assertions.assertTrue(semanticQueryResp.getResultList().size() > 0); Assertions.assertTrue(semanticQueryResp.getResultList().size() > 0);
} }
@Test
public void testAuthorization_row_permission() throws Exception {
User tom = DataUtils.getUserTom();
SemanticQueryResp semanticQueryResp =
queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'", tom);
Assertions.assertNotNull(semanticQueryResp.getQueryAuthorization().getMessage());
Assertions.assertTrue(semanticQueryResp.getSql().contains("user_name = 'tom'")
|| semanticQueryResp.getSql().contains("`user_name` = 'tom'"));
}
} }

View File

@@ -53,11 +53,9 @@ public class QueryByStructTest extends BaseTest {
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser()); semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(3, semanticQueryResp.getColumns().size()); assertEquals(3, semanticQueryResp.getColumns().size());
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0); QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
assertEquals("用户", firstColumn.getName()); assertEquals("用户", firstColumn.getName());
QueryColumn secondColumn = semanticQueryResp.getColumns().get(1); QueryColumn secondColumn = semanticQueryResp.getColumns().get(1);
assertEquals("部门", secondColumn.getName()); assertEquals("部门", secondColumn.getName());
QueryColumn thirdColumn = semanticQueryResp.getColumns().get(2);
assertEquals("访问次数", thirdColumn.getName());
assertTrue(semanticQueryResp.getResultList().size() > 0); assertTrue(semanticQueryResp.getResultList().size() > 0);
} }
@@ -68,7 +66,7 @@ public class QueryByStructTest extends BaseTest {
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser()); semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(1, semanticQueryResp.getColumns().size()); assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0); QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
assertEquals("访问次数", queryColumn.getName()); assertEquals("停留时长", queryColumn.getName());
assertEquals(1, semanticQueryResp.getResultList().size()); assertEquals(1, semanticQueryResp.getResultList().size());
} }
@@ -81,7 +79,7 @@ public class QueryByStructTest extends BaseTest {
QueryColumn firstColumn = result.getColumns().get(0); QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1); QueryColumn secondColumn = result.getColumns().get(1);
assertEquals("部门", firstColumn.getName()); assertEquals("部门", firstColumn.getName());
assertEquals("访问次数", secondColumn.getName()); assertEquals("停留时长", secondColumn.getName());
assertNotNull(result.getResultList().size()); assertNotNull(result.getResultList().size());
} }
@@ -103,7 +101,7 @@ public class QueryByStructTest extends BaseTest {
QueryColumn firstColumn = result.getColumns().get(0); QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1); QueryColumn secondColumn = result.getColumns().get(1);
assertEquals("部门", firstColumn.getName()); assertEquals("部门", firstColumn.getName());
assertEquals("访问次数", secondColumn.getName()); assertEquals("停留时长", secondColumn.getName());
assertEquals(1, result.getResultList().size()); assertEquals(1, result.getResultList().size());
assertEquals("HR", result.getResultList().get(0).get("department").toString()); assertEquals("HR", result.getResultList().get(0).get("department").toString());
} }
@@ -122,7 +120,7 @@ public class QueryByStructTest extends BaseTest {
User tom = DataUtils.getUserTom(); User tom = DataUtils.getUserTom();
Aggregator aggregator = new Aggregator(); Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.SUM); aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setColumn("stay_hours"); aggregator.setColumn("pv_avg");
QueryStructReq queryStructReq = QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("department"), aggregator); buildQueryStructReq(Arrays.asList("department"), aggregator);
assertThrows(InvalidPermissionException.class, assertThrows(InvalidPermissionException.class,
@@ -134,7 +132,7 @@ public class QueryByStructTest extends BaseTest {
User tom = DataUtils.getUserTom(); User tom = DataUtils.getUserTom();
Aggregator aggregator = new Aggregator(); Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.SUM); aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setColumn("pv"); aggregator.setColumn("stay_hours");
QueryStructReq queryStructReq1 = QueryStructReq queryStructReq1 =
buildQueryStructReq(Collections.singletonList("department"), aggregator); buildQueryStructReq(Collections.singletonList("department"), aggregator);
SemanticQueryResp semanticQueryResp = semanticLayerService.queryByReq(queryStructReq1, tom); SemanticQueryResp semanticQueryResp = semanticLayerService.queryByReq(queryStructReq1, tom);

View File

@@ -46,6 +46,6 @@ public class TranslateTest extends BaseTest {
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department")); assertTrue(explain.getQuerySQL().contains("department"));
assertTrue(explain.getQuerySQL().contains("pv")); assertTrue(explain.getQuerySQL().contains("stay_hours"));
} }
} }