[improvement][headless]Move discovery of query models from SemanticNode to SqlQueryParser.

[improvement][headless]Move discovery of query models from SemanticNode to `SqlQueryParser`.
This commit is contained in:
jerryjzhang
2024-12-22 20:29:51 +08:00
parent d8b8c4e6b9
commit 214d90772d
10 changed files with 180 additions and 182 deletions

View File

@@ -3,20 +3,14 @@ package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.SchemaType;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@Data
@AllArgsConstructor
@@ -72,31 +66,4 @@ public class SemanticSchemaResp {
return names;
}
public Map<String, String> getNameToBizNameMap() {
// support fieldName and field alias to bizName
Map<String, String> dimensionResults = dimensions.stream().flatMap(
entry -> getPairStream(entry.getAlias(), entry.getName(), entry.getBizName()))
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
Map<String, String> metricResults = metrics.stream().flatMap(
entry -> getPairStream(entry.getAlias(), entry.getName(), entry.getBizName()))
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
dimensionResults.putAll(metricResults);
return dimensionResults;
}
private Stream<Pair<String, String>> getPairStream(String aliasStr, String name,
String bizName) {
Set<Pair<String, String>> elements = new HashSet<>();
elements.add(Pair.of(name, bizName));
if (StringUtils.isNotBlank(aliasStr)) {
List<String> aliasList = SchemaItem.getAliasList(aliasStr);
for (String alias : aliasList) {
elements.add(Pair.of(alias, bizName));
}
}
return elements.stream();
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.core.pojo;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ColumnOrder;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
@@ -9,6 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import lombok.Data;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@@ -19,22 +21,40 @@ import java.util.stream.Collectors;
@Data
public class OntologyQuery {
private Set<ModelResp> models = Sets.newHashSet();
private Set<MetricSchemaResp> metrics = Sets.newHashSet();
private Set<DimSchemaResp> dimensions = Sets.newHashSet();
private Map<String, ModelResp> modelMap = Maps.newHashMap();
private Map<String, Set<MetricSchemaResp>> metricMap = Maps.newHashMap();
private Map<String, Set<DimSchemaResp>> dimensionMap = Maps.newHashMap();
private Set<String> fields = Sets.newHashSet();
private Long limit;
private List<ColumnOrder> order;
private boolean nativeQuery = true;
private AggOption aggOption = AggOption.NATIVE;
public Set<MetricSchemaResp> getMetricsByModel(Long modelId) {
return metrics.stream().filter(m -> m.getModelId().equals(modelId))
.collect(Collectors.toSet());
public Set<ModelResp> getModels() {
return modelMap.values().stream().collect(Collectors.toSet());
}
public Set<DimSchemaResp> getDimensionsByModel(Long modelId) {
return dimensions.stream().filter(m -> m.getModelId().equals(modelId))
.collect(Collectors.toSet());
public Set<DimSchemaResp> getDimensions() {
Set<DimSchemaResp> dimensions = Sets.newHashSet();
dimensionMap.entrySet().forEach(entry -> {
dimensions.addAll(entry.getValue());
});
return dimensions;
}
public Set<MetricSchemaResp> getMetrics() {
Set<MetricSchemaResp> metrics = Sets.newHashSet();
metricMap.entrySet().forEach(entry -> {
metrics.addAll(entry.getValue());
});
return metrics;
}
public Set<MetricSchemaResp> getMetricsByModel(String modelName) {
return metricMap.get(modelName);
}
public Set<DimSchemaResp> getDimensionsByModel(String modelName) {
return dimensionMap.get(modelName);
}
}

View File

@@ -54,7 +54,9 @@ public class DimExpressionParser implements QueryParser {
for (DimSchemaResp queryDim : queryDimensions) {
queryDim.getFields().addAll(SqlSelectHelper.getFieldsFromExpr(queryDim.getExpr()));
queryFields.addAll(queryDim.getFields());
dim2Expr.put(queryDim.getBizName(), queryDim.getExpr());
if (!queryDim.getBizName().equals(queryDim.getExpr())) {
dim2Expr.put(queryDim.getBizName(), queryDim.getExpr());
}
}
return dim2Expr;

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.core.translator.parser;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
@@ -8,10 +7,9 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.Ontology;
import com.tencent.supersonic.headless.core.pojo.OntologyQuery;
@@ -20,10 +18,12 @@ import com.tencent.supersonic.headless.core.pojo.SqlQuery;
import com.tencent.supersonic.headless.core.utils.SqlGenerateUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* This parser rewrites S2SQL including conversion from metric/dimension name to bizName and build
@@ -40,11 +40,20 @@ public class SqlQueryParser implements QueryParser {
@Override
public void parse(QueryStatement queryStatement) throws Exception {
// build ontologyQuery
SqlQuery sqlQuery = queryStatement.getSqlQuery();
List<String> queryFields = SqlSelectHelper.getAllSelectFields(sqlQuery.getSql());
Ontology ontology = queryStatement.getOntology();
OntologyQuery ontologyQuery = buildOntologyQuery(ontology, queryFields);
queryStatement.setOntologyQuery(ontologyQuery);
AggOption sqlQueryAggOption = getAggOption(sqlQuery.getSql(), ontologyQuery.getMetrics());
ontologyQuery.setAggOption(sqlQueryAggOption);
convertNameToBizName(queryStatement);
rewriteOrderBy(queryStatement);
// fill sqlQuery
SqlQuery sqlQuery = queryStatement.getSqlQuery();
String tableName = SqlSelectHelper.getTableName(sqlQuery.getSql());
if (StringUtils.isEmpty(tableName)) {
return;
@@ -59,28 +68,10 @@ public class SqlQueryParser implements QueryParser {
sqlQuery.setWithAlias(false);
}
// build ontologyQuery
Ontology ontology = queryStatement.getOntology();
List<String> allQueryFields = SqlSelectHelper.getAllSelectFields(sqlQuery.getSql());
OntologyQuery ontologyQuery = new OntologyQuery();
queryStatement.setOntologyQuery(ontologyQuery);
List<MetricSchemaResp> queryMetrics = findQueryMetrics(ontology, allQueryFields);
ontologyQuery.getMetrics().addAll(queryMetrics);
List<DimSchemaResp> queryDimensions = findQueryDimensions(ontology, allQueryFields);
ontologyQuery.getDimensions().addAll(queryDimensions);
List<ModelResp> queryModels = findQueryModels(ontology, queryMetrics, queryDimensions);
ontologyQuery.getModels().addAll(queryModels);
AggOption sqlQueryAggOption = getAggOption(sqlQuery.getSql(), queryMetrics);
ontologyQuery.setAggOption(sqlQueryAggOption);
log.info("parse sqlQuery [{}] ", sqlQuery);
}
private AggOption getAggOption(String sql, List<MetricSchemaResp> metricSchemas) {
private AggOption getAggOption(String sql, Set<MetricSchemaResp> metricSchemas) {
if (SqlSelectFunctionHelper.hasAggregateFunction(sql)) {
return AggOption.AGGREGATION;
}
@@ -113,9 +104,36 @@ public class SqlQueryParser implements QueryParser {
return AggOption.DEFAULT;
}
private Map<String, String> getNameToBizNameMap(OntologyQuery query) {
// support fieldName and field alias to bizName
Map<String, String> dimensionResults = query.getDimensions().stream().flatMap(
entry -> getPairStream(entry.getAlias(), entry.getName(), entry.getBizName()))
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
Map<String, String> metricResults = query.getMetrics().stream().flatMap(
entry -> getPairStream(entry.getAlias(), entry.getName(), entry.getBizName()))
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
dimensionResults.putAll(metricResults);
return dimensionResults;
}
private Stream<Pair<String, String>> getPairStream(String aliasStr, String name,
String bizName) {
Set<Pair<String, String>> elements = new HashSet<>();
elements.add(Pair.of(name, bizName));
if (StringUtils.isNotBlank(aliasStr)) {
List<String> aliasList = SchemaItem.getAliasList(aliasStr);
for (String alias : aliasList) {
elements.add(Pair.of(alias, bizName));
}
}
return elements.stream();
}
private void convertNameToBizName(QueryStatement queryStatement) {
SemanticSchemaResp semanticSchema = queryStatement.getSemanticSchema();
Map<String, String> fieldNameToBizNameMap = semanticSchema.getNameToBizNameMap();
Map<String, String> fieldNameToBizNameMap =
getNameToBizNameMap(queryStatement.getOntologyQuery());
String sql = queryStatement.getSqlQuery().getSql();
log.debug("dataSetId:{},convert name to bizName before:{}", queryStatement.getDataSetId(),
sql);
@@ -136,57 +154,70 @@ public class SqlQueryParser implements QueryParser {
queryStatement.getSqlQuery().setSql(newSql);
}
public List<MetricSchemaResp> findQueryMetrics(Ontology ontology, List<String> bizNames) {
Map<String, MetricSchemaResp> metricLowerToNameMap = ontology.getMetrics().stream().collect(
Collectors.toMap(entry -> entry.getBizName().toLowerCase(), entry -> entry));
return bizNames.stream().map(String::toLowerCase)
.filter(entry -> metricLowerToNameMap.containsKey(entry))
.map(entry -> metricLowerToNameMap.get(entry)).collect(Collectors.toList());
}
private OntologyQuery buildOntologyQuery(Ontology ontology, List<String> queryFields) {
OntologyQuery ontologyQuery = new OntologyQuery();
Set<String> fields = Sets.newHashSet(queryFields);
public List<DimSchemaResp> findQueryDimensions(Ontology ontology, List<String> bizNames) {
Map<String, DimSchemaResp> dimLowerToNameMap = ontology.getDimensions().stream().collect(
Collectors.toMap(entry -> entry.getBizName().toLowerCase(), entry -> entry));
return bizNames.stream().map(String::toLowerCase)
.filter(entry -> dimLowerToNameMap.containsKey(entry))
.map(entry -> dimLowerToNameMap.get(entry)).collect(Collectors.toList());
}
public List<ModelResp> findQueryModels(Ontology ontology, List<MetricSchemaResp> queryMetrics,
List<DimSchemaResp> queryDimensions) {
// first, sort models based on the number of query metrics
Map<String, Integer> modelMetricCount = Maps.newHashMap();
queryMetrics.forEach(m -> {
if (!modelMetricCount.containsKey(m.getModelBizName())) {
modelMetricCount.put(m.getModelBizName(), 1);
} else {
int count = modelMetricCount.get(m.getModelBizName());
modelMetricCount.put(m.getModelBizName(), count + 1);
}
// find belonging model for every querying metrics
ontology.getMetricMap().entrySet().forEach(entry -> {
String modelName = entry.getKey();
entry.getValue().forEach(m -> {
if (fields.contains(m.getName()) || fields.contains(m.getBizName())) {
if (!ontologyQuery.getMetricMap().containsKey(modelName)) {
ontologyQuery.getMetricMap().put(modelName, Sets.newHashSet());
}
ontologyQuery.getModelMap().put(modelName,
ontology.getModelMap().get(modelName));
ontologyQuery.getMetricMap().get(modelName).add(m);
fields.remove(m.getName());
fields.remove(m.getBizName());
}
});
});
List<String> metricsDataModels = modelMetricCount.entrySet().stream()
.sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).map(e -> e.getKey())
.collect(Collectors.toList());
// second, sort models based on the number of query dimensions
Map<String, Integer> modelDimCount = Maps.newHashMap();
queryDimensions.forEach(m -> {
if (!modelDimCount.containsKey(m.getModelBizName())) {
modelDimCount.put(m.getModelBizName(), 1);
} else {
int count = modelDimCount.get(m.getModelBizName());
modelDimCount.put(m.getModelBizName(), count + 1);
}
});
List<String> dimDataModels = modelDimCount.entrySet().stream()
.sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).map(e -> e.getKey())
.collect(Collectors.toList());
// first try to find all querying dimensions in the models with querying metrics.
ontology.getDimensionMap().entrySet().stream()
.filter(entry -> ontologyQuery.getMetricMap().containsKey(entry.getKey()))
.forEach(entry -> {
String modelName = entry.getKey();
entry.getValue().forEach(d -> {
if (fields.contains(d.getName()) || fields.contains(d.getBizName())) {
if (!ontologyQuery.getDimensionMap().containsKey(entry.getKey())) {
ontologyQuery.getDimensionMap().put(entry.getKey(),
Sets.newHashSet());
}
ontologyQuery.getModelMap().put(modelName,
ontology.getModelMap().get(modelName));
ontologyQuery.getDimensionMap().get(entry.getKey()).add(d);
fields.remove(d.getName());
fields.remove(d.getBizName());
}
});
});
Set<String> dataModelNames = Sets.newLinkedHashSet();
dataModelNames.addAll(dimDataModels);
dataModelNames.addAll(metricsDataModels);
return dataModelNames.stream().map(bizName -> ontology.getModelMap().get(bizName))
.collect(Collectors.toList());
// if there are still fields not found belonging models, try to find in the models without
// querying metrics.
if (!fields.isEmpty()) {
ontology.getDimensionMap().entrySet().forEach(entry -> {
String modelName = entry.getKey();
if (!ontologyQuery.getDimensionMap().containsKey(modelName)) {
entry.getValue().forEach(d -> {
if (fields.contains(d.getName()) || fields.contains(d.getBizName())) {
if (!ontologyQuery.getDimensionMap().containsKey(modelName)) {
ontologyQuery.getDimensionMap().put(modelName, Sets.newHashSet());
}
ontologyQuery.getModelMap().put(modelName,
ontology.getModelMap().get(modelName));
ontologyQuery.getDimensionMap().get(modelName).add(d);
fields.remove(d.getName());
fields.remove(d.getBizName());
}
});
}
});
}
return ontologyQuery;
}
}

View File

@@ -91,9 +91,9 @@ public class SqlBuilder {
for (int i = 0; i < dataModels.size(); i++) {
final ModelResp dataModel = dataModels.get(i);
final Set<DimSchemaResp> queryDimensions =
ontologyQuery.getDimensionsByModel(dataModel.getId());
ontologyQuery.getDimensionsByModel(dataModel.getName());
final Set<MetricSchemaResp> queryMetrics =
ontologyQuery.getMetricsByModel(dataModel.getId());
ontologyQuery.getMetricsByModel(dataModel.getName());
List<String> primary = new ArrayList<>();
for (Identify identify : dataModel.getIdentifiers()) {
@@ -248,8 +248,12 @@ public class SqlBuilder {
TableView tableView = new TableView();
EngineType engineType = EngineType.fromString(schema.getOntology().getDatabase().getType());
Set<String> queryFields = tableView.getFields();
queryMetrics.stream().forEach(m -> queryFields.addAll(m.getFields()));
queryDimensions.stream().forEach(d -> queryFields.addAll(d.getFields()));
if (Objects.nonNull(queryMetrics)) {
queryMetrics.stream().forEach(m -> queryFields.addAll(m.getFields()));
}
if (Objects.nonNull(queryDimensions)) {
queryDimensions.stream().forEach(d -> queryFields.addAll(d.getFields()));
}
try {
for (String field : queryFields) {

View File

@@ -697,31 +697,31 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
queryMetricReq.setDateInfo(null);
}
// 4. set groups
List<String> dimensionBizNames = dimensionResps.stream()
List<String> dimensionNames = dimensionResps.stream()
.filter(entry -> modelCluster.getModelIds().contains(entry.getModelId()))
.filter(entry -> queryMetricReq.getDimensionNames().contains(entry.getName())
|| queryMetricReq.getDimensionNames().contains(entry.getBizName())
|| queryMetricReq.getDimensionIds().contains(entry.getId()))
.map(SchemaItem::getBizName).collect(Collectors.toList());
.map(SchemaItem::getName).collect(Collectors.toList());
QueryStructReq queryStructReq = new QueryStructReq();
DateConf dateInfo = queryMetricReq.getDateInfo();
if (Objects.nonNull(dateInfo) && dateInfo.isGroupByDate()) {
queryStructReq.getGroups().add(dateInfo.getDateField());
}
if (!CollectionUtils.isEmpty(dimensionBizNames)) {
queryStructReq.getGroups().addAll(dimensionBizNames);
if (!CollectionUtils.isEmpty(dimensionNames)) {
queryStructReq.getGroups().addAll(dimensionNames);
}
// 5. set aggregators
List<String> metricBizNames = metricResps.stream()
List<String> metricNames = metricResps.stream()
.filter(entry -> modelCluster.getModelIds().contains(entry.getModelId()))
.map(SchemaItem::getBizName).collect(Collectors.toList());
if (CollectionUtils.isEmpty(metricBizNames)) {
.map(SchemaItem::getName).collect(Collectors.toList());
if (CollectionUtils.isEmpty(metricNames)) {
throw new IllegalArgumentException(
"Invalid input parameters, unable to obtain valid metrics");
}
List<Aggregator> aggregators = new ArrayList<>();
for (String metricBizName : metricBizNames) {
for (String metricBizName : metricNames) {
Aggregator aggregator = new Aggregator();
aggregator.setColumn(metricBizName);
aggregators.add(aggregator);

View File

@@ -131,7 +131,6 @@ public class MetricTest extends BaseTest {
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 6;
assert actualResult.getQuerySql().contains("s2_pv_uv_statis");
assert actualResult.getQuerySql().contains("s2_user_department");
}
@Test
@@ -237,6 +236,7 @@ public class MetricTest extends BaseTest {
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testMetricGroupBySum() throws Exception {
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", agent.getId());
QueryResult expectedResult = new QueryResult();

View File

@@ -22,19 +22,6 @@ public class QueryByMetricTest extends BaseTest {
@Autowired
protected MetricService metricService;
@Test
public void testWithMetricAndDimensionBizNames() throws Exception {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
.value(LocalDate.now().toString()).build());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testWithMetricAndDimensionNames() throws Exception {
@@ -51,21 +38,23 @@ public class QueryByMetricTest extends BaseTest {
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testWithDomainId() throws Exception {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
.value(LocalDate.now().toString()).build());
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
queryMetricReq.setDimensionNames(Arrays.asList("用户名", "部门"));
queryMetricReq.getFilters()
.add(Filter.builder().name("数据日期").operator(FilterOperatorEnum.MINOR_THAN_EQUALS)
.relation(Filter.Relation.FILTER).value(LocalDate.now().toString())
.build());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
queryMetricReq.setDomainId(2L);
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
queryMetricReq.setDimensionNames(Arrays.asList("用户名", "部门"));
assertThrows(IllegalArgumentException.class,
() -> queryByMetric(queryMetricReq, User.getDefaultUser()));
}
@@ -76,9 +65,10 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
.value(LocalDate.now().toString()).build());
queryMetricReq.getFilters()
.add(Filter.builder().name("数据日期").operator(FilterOperatorEnum.MINOR_THAN_EQUALS)
.relation(Filter.Relation.FILTER).value(LocalDate.now().toString())
.build());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());

View File

@@ -87,17 +87,6 @@ public class QueryBySqlTest extends BaseTest {
assertTrue(result2.isUseCache());
}
@Test
public void testBizNameQuery() throws Exception {
SemanticQueryResp result1 =
queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'");
SemanticQueryResp result2 = queryBySql("SELECT SUM(访问次数) FROM 超音数PVUV统计 WHERE 部门 ='HR'");
assertEquals(1, result1.getColumns().size());
assertEquals(1, result2.getColumns().size());
assertEquals(result1.getColumns().get(0), result2.getColumns().get(0));
assertEquals(result1.getResultList(), result2.getResultList());
}
@Test
public void testAuthorization_model() {
User alice = DataUtils.getUserAlice();
@@ -116,8 +105,7 @@ public class QueryBySqlTest extends BaseTest {
@Test
public void testAuthorization_sensitive_metric_jack() throws Exception {
User jack = DataUtils.getUserJack();
SemanticQueryResp semanticQueryResp =
queryBySql("SELECT SUM(stay_hours) FROM 停留时长统计", jack);
SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(停留时长) FROM 停留时长统计", jack);
Assertions.assertTrue(semanticQueryResp.getResultList().size() > 0);
}

View File

@@ -14,11 +14,7 @@ import com.tencent.supersonic.headless.core.cache.QueryCache;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.api.*;
import java.util.ArrayList;
import java.util.Arrays;
@@ -32,13 +28,14 @@ import static org.junit.Assert.assertTrue;
@Slf4j
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@Disabled
public class QueryByStructTest extends BaseTest {
@Test
@Order(0)
public void testCacheQuery() {
QueryStructReq queryStructReq1 = buildQueryStructReq(Arrays.asList("department"));
QueryStructReq queryStructReq2 = buildQueryStructReq(Arrays.asList("department"));
QueryStructReq queryStructReq1 = buildQueryStructReq(Arrays.asList("部门"));
QueryStructReq queryStructReq2 = buildQueryStructReq(Arrays.asList("部门"));
QueryCache queryCache = ComponentFactory.getQueryCache();
String cacheKey1 = queryCache.getCacheKey(queryStructReq1);
String cacheKey2 = queryCache.getCacheKey(queryStructReq2);
@@ -48,7 +45,7 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testDetailQuery() throws Exception {
QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
buildQueryStructReq(Arrays.asList("用户名", "部门"), QueryType.DETAIL);
SemanticQueryResp semanticQueryResp =
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(3, semanticQueryResp.getColumns().size());
@@ -72,7 +69,7 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testGroupByQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("部门"));
SemanticQueryResp result =
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(2, result.getColumns().size());
@@ -85,7 +82,7 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testFilterQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("部门"));
List<Filter> dimensionFilters = new ArrayList<>();
Filter filter = new Filter();
filter.setName("部门");
@@ -103,14 +100,14 @@ public class QueryByStructTest extends BaseTest {
assertEquals("部门", firstColumn.getName());
assertEquals("停留时长", secondColumn.getName());
assertEquals(1, result.getResultList().size());
assertEquals("HR", result.getResultList().get(0).get("department").toString());
assertEquals("HR", result.getResultList().get(0).get("部门").toString());
}
@Test
public void testAuthorization_model() {
User alice = DataUtils.getUserAlice();
setDomainNotOpenToAll();
QueryStructReq queryStructReq1 = buildQueryStructReq(Arrays.asList("department"));
QueryStructReq queryStructReq1 = buildQueryStructReq(Arrays.asList("部门"));
assertThrows(InvalidPermissionException.class,
() -> semanticLayerService.queryByReq(queryStructReq1, alice));
}
@@ -120,9 +117,8 @@ public class QueryByStructTest extends BaseTest {
User tom = DataUtils.getUserTom();
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setColumn("pv_avg");
QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("department"), aggregator);
aggregator.setColumn("人均访问次数");
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("部门"), aggregator);
assertThrows(InvalidPermissionException.class,
() -> semanticLayerService.queryByReq(queryStructReq, tom));
}
@@ -132,11 +128,11 @@ public class QueryByStructTest extends BaseTest {
User tom = DataUtils.getUserTom();
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setColumn("stay_hours");
aggregator.setColumn("停留时长");
QueryStructReq queryStructReq1 =
buildQueryStructReq(Collections.singletonList("department"), aggregator);
buildQueryStructReq(Collections.singletonList("部门"), aggregator);
SemanticQueryResp semanticQueryResp = semanticLayerService.queryByReq(queryStructReq1, tom);
Assertions.assertNotNull(semanticQueryResp.getQueryAuthorization().getMessage());
Assertions.assertTrue(semanticQueryResp.getSql().contains("user_name = 'tom'"));
Assertions.assertTrue(semanticQueryResp.getSql().contains("用户名 = 'tom'"));
}
}