(improvement)(auth) When checking auth, only the models involved in the query are considered rather than the models included in the data set. #1625 (#1731)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-09-29 00:33:10 +08:00
committed by GitHub
parent 3a11ccb6e9
commit 47df22d1a0
4 changed files with 94 additions and 52 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.server.aspect;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
@@ -81,7 +82,7 @@ public class S2DataPermissionAspect {
}
SemanticSchemaResp semanticSchemaResp = getSemanticSchemaResp(queryReq);
List<Long> modelIds = getModelIds(semanticSchemaResp);
Set<Long> modelIds = getModelIdInQuery(queryReq, semanticSchemaResp);
// 2. determine whether admin of the model
if (checkModelAdmin(user, modelIds)) {
@@ -111,7 +112,7 @@ public class S2DataPermissionAspect {
private void checkColPermission(
SemanticQueryReq semanticQueryReq,
AuthorizedResourceResp authorizedResource,
List<Long> modelIds,
Set<Long> modelIds,
SemanticSchemaResp semanticSchemaResp) {
// get high sensitive fields in query
Set<String> bizNamesInQueryReq = getBizNameInQueryReq(semanticQueryReq, semanticSchemaResp);
@@ -132,13 +133,26 @@ public class S2DataPermissionAspect {
if (!CollectionUtils.isEmpty(sensitiveBizNameInQuery)) {
Set<String> sensitiveResNames =
semanticSchemaResp.getNameFromBizNames(sensitiveBizNameInQuery);
List<String> modelAdmin = modelService.getModelAdmin(modelIds.get(0));
List<String> modelAdmin = modelService.getModelAdmin(modelIds.iterator().next());
String message =
String.format("存在以下敏感资源:%s您暂无权限请联系管理员%s申请", sensitiveResNames, modelAdmin);
throw new InvalidPermissionException(message);
}
}
private Set<Long> getModelIdInQuery(
SemanticQueryReq semanticQueryReq, SemanticSchemaResp semanticSchemaResp) {
if (semanticQueryReq instanceof QuerySqlReq) {
QuerySqlReq querySqlReq = (QuerySqlReq) semanticQueryReq;
return queryStructUtils.getModelIdFromSql(querySqlReq, semanticSchemaResp);
}
if (semanticQueryReq instanceof QueryStructReq) {
QueryStructReq queryStructReq = (QueryStructReq) semanticQueryReq;
return queryStructUtils.getModelIdsFromStruct(queryStructReq, semanticSchemaResp);
}
return Sets.newHashSet();
}
private void checkRowPermission(
SemanticQueryReq queryReq, AuthorizedResourceResp authorizedResource) {
if (queryReq instanceof QuerySqlReq) {
@@ -167,12 +181,6 @@ public class S2DataPermissionAspect {
return schemaService.fetchSemanticSchema(filter);
}
private List<Long> getModelIds(SemanticSchemaResp semanticSchemaResp) {
return semanticSchemaResp.getModelResps().stream()
.map(ModelResp::getId)
.collect(Collectors.toList());
}
private void doRowPermission(
QuerySqlReq querySqlReq, AuthorizedResourceResp authorizedResource) {
log.debug("start doRowPermission logic");
@@ -246,7 +254,7 @@ public class S2DataPermissionAspect {
}
}
public boolean checkModelAdmin(User user, List<Long> modelIds) {
public boolean checkModelAdmin(User user, Set<Long> modelIds) {
List<ModelResp> modelListAdmin =
modelService.getModelListWithAuth(user, null, AuthType.ADMIN);
if (CollectionUtils.isEmpty(modelListAdmin)) {
@@ -258,7 +266,7 @@ public class S2DataPermissionAspect {
}
}
public void checkModelVisible(User user, List<Long> modelIds) {
public void checkModelVisible(User user, Set<Long> modelIds) {
List<Long> modelListVisible =
modelService.getModelListWithAuth(user, null, AuthType.VISIBLE).stream()
.map(ModelResp::getId)
@@ -303,9 +311,9 @@ public class S2DataPermissionAspect {
return highSensitiveCols;
}
public AuthorizedResourceResp getAuthorizedResource(User user, List<Long> modelIds) {
public AuthorizedResourceResp getAuthorizedResource(User user, Set<Long> modelIds) {
QueryAuthResReq queryAuthResReq = new QueryAuthResReq();
queryAuthResReq.setModelIds(modelIds);
queryAuthResReq.setModelIds(new ArrayList<>(modelIds));
AuthorizedResourceResp authorizedResource = fetchAuthRes(queryAuthResReq, user);
log.info(
"user:{}, domainId:{}, after queryAuthorizedResources:{}",
@@ -321,17 +329,17 @@ public class S2DataPermissionAspect {
}
public void addHint(
List<Long> modelIds,
Set<Long> modelIds,
SemanticQueryResp queryResultWithColumns,
AuthorizedResourceResp authorizedResource) {
List<DimensionFilter> filters = authorizedResource.getFilters();
if (CollectionUtils.isEmpty(filters)) {
return;
}
List<String> admins = modelService.getModelAdmin(modelIds.get(0));
List<String> admins = modelService.getModelAdmin(modelIds.iterator().next());
if (!CollectionUtils.isEmpty(filters)) {
ModelResp modelResp = modelService.getModel(modelIds.get(0));
ModelResp modelResp = modelService.getModel(modelIds.iterator().next());
List<String> exprList = new ArrayList<>();
List<String> descList = new ArrayList<>();
filters.stream()

View File

@@ -214,7 +214,9 @@ public class ModelServiceImpl implements ModelService {
String message = String.format("模型英文名[%s]需要为下划线字母数字组合, 请修改", modelReq.getBizName());
throw new InvalidArgumentException(message);
}
if (modelReq.getModelDetail() == null) {
return;
}
List<Dim> dims = modelReq.getModelDetail().getDimensions();
List<Measure> measures = modelReq.getModelDetail().getMeasures();
List<Identify> identifies = modelReq.getModelDetail().getIdentifiers();

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.headless.server.utils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.Aggregator;
@@ -15,10 +17,8 @@ import com.tencent.supersonic.headless.api.pojo.MetaFilter;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.server.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
@@ -47,14 +47,8 @@ import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
@Component
public class QueryStructUtils {
public static Set<String> internalTimeCols =
public static Set<String> internalCols =
new HashSet<>(Arrays.asList("dayno", "sys_imp_date", "sys_imp_week", "sys_imp_month"));
public static Set<String> internalCols;
static {
internalCols = new HashSet<>(Arrays.asList("plat_sys_var"));
internalCols.addAll(internalTimeCols);
}
private final DateModeUtils dateModeUtils;
private final SqlFilterUtils sqlFilterUtils;
@@ -127,33 +121,71 @@ public class QueryStructUtils {
return new HashSet<>(SqlSelectHelper.getAllSelectFields(querySqlReq.getSql()));
}
public Set<String> getBizNameFromSql(
public Set<Long> getModelIdsFromStruct(
QueryStructReq queryStructReq, SemanticSchemaResp semanticSchemaResp) {
Set<Long> modelIds = Sets.newHashSet();
Set<String> bizNameFromStruct = getBizNameFromStruct(queryStructReq);
modelIds.addAll(
semanticSchemaResp.getMetrics().stream()
.filter(metric -> bizNameFromStruct.contains(metric.getBizName()))
.map(MetricResp::getModelId)
.collect(Collectors.toSet()));
modelIds.addAll(
semanticSchemaResp.getDimensions().stream()
.filter(dimension -> bizNameFromStruct.contains(dimension.getBizName()))
.map(DimensionResp::getModelId)
.collect(Collectors.toList()));
return modelIds;
}
private List<MetricResp> getMetricsFromSql(
QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) {
Set<String> resNameSet = getResName(querySqlReq);
Set<String> resNameEnSet = new HashSet<>();
if (semanticSchemaResp != null) {
List<MetricSchemaResp> metrics = semanticSchemaResp.getMetrics();
List<DimSchemaResp> dimensions = semanticSchemaResp.getDimensions();
metrics.stream()
.forEach(
o -> {
if (resNameSet.contains(o.getName())
|| resNameSet.contains(o.getBizName())) {
resNameEnSet.add(o.getBizName());
return semanticSchemaResp.getMetrics().stream()
.filter(
m ->
resNameSet.contains(m.getName())
|| resNameSet.contains(m.getBizName()))
.collect(Collectors.toList());
}
});
dimensions.stream()
.forEach(
o -> {
if (resNameSet.contains(o.getName())
|| resNameSet.contains(o.getBizName())) {
resNameEnSet.add(o.getBizName());
return Lists.newArrayList();
}
});
private List<DimensionResp> getDimensionsFromSql(
QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) {
Set<String> resNameSet = getResName(querySqlReq);
if (semanticSchemaResp != null) {
return semanticSchemaResp.getDimensions().stream()
.filter(
m ->
resNameSet.contains(m.getName())
|| resNameSet.contains(m.getBizName()))
.collect(Collectors.toList());
}
return resNameEnSet.stream()
.filter(res -> !internalCols.contains(res))
.collect(Collectors.toSet());
return Lists.newArrayList();
}
public Set<Long> getModelIdFromSql(
QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) {
Set<Long> modelIds = Sets.newHashSet();
List<DimensionResp> dimensions = getDimensionsFromSql(querySqlReq, semanticSchemaResp);
List<MetricResp> metrics = getMetricsFromSql(querySqlReq, semanticSchemaResp);
modelIds.addAll(
dimensions.stream().map(DimensionResp::getModelId).collect(Collectors.toList()));
modelIds.addAll(metrics.stream().map(MetricResp::getModelId).collect(Collectors.toList()));
return modelIds;
}
public Set<String> getBizNameFromSql(
QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) {
Set<String> bizNames = Sets.newHashSet();
List<DimensionResp> dimensions = getDimensionsFromSql(querySqlReq, semanticSchemaResp);
List<MetricResp> metrics = getMetricsFromSql(querySqlReq, semanticSchemaResp);
bizNames.addAll(
dimensions.stream().map(DimensionResp::getBizName).collect(Collectors.toList()));
bizNames.addAll(metrics.stream().map(MetricResp::getBizName).collect(Collectors.toList()));
return bizNames;
}
public ItemDateResp getItemDateResp(QueryStructReq queryStructCmd) {

View File

@@ -101,7 +101,7 @@ public class S2VisitsDemo extends S2BaseDemo {
// create data set
DataSetResp s2DataSet = addDataSet(s2Domain);
addAuthGroup_1(stayTimeModel);
addAuthGroup_2(stayTimeModel);
addAuthGroup_2(pvUvModel);
// create terms and plugin
addTerm(s2Domain);
@@ -513,9 +513,9 @@ public class S2VisitsDemo extends S2BaseDemo {
authService.addOrUpdateAuthGroup(authGroupReq);
}
public void addAuthGroup_2(ModelResp stayTimeModel) {
public void addAuthGroup_2(ModelResp pvuvModel) {
AuthGroup authGroupReq = new AuthGroup();
authGroupReq.setModelId(stayTimeModel.getId());
authGroupReq.setModelId(pvuvModel.getId());
authGroupReq.setName("tom_row_permission");
List<AuthRule> authRules = new ArrayList<>();