(improvement)(semantic) perfect dsl permission (#161)

This commit is contained in:
mainmain
2023-09-27 21:01:44 +08:00
committed by GitHub
parent 6047c787b3
commit e688422ec3
10 changed files with 233 additions and 21 deletions

View File

@@ -14,6 +14,7 @@ public class QueryDataReq {
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<QueryFilter> dimensionFilters = new HashSet<>();
private Set<QueryFilter> metricFilters = new HashSet<>();
private DateConf dateInfo;
private Long queryId = 7L;
private Integer parseId = 2;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
@@ -38,6 +39,8 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
import java.util.Map;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
@@ -295,27 +298,73 @@ public class QueryServiceImpl implements QueryService {
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
}
if (parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)
&& CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
&& (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())
|| CollectionUtils.isNotEmpty(queryData.getMetricFilters()))) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
String correctorSql = llmResp.getCorrectorSql();
log.info("correctorSql before replacing:{}", correctorSql);
List<FilterExpression> filterExpressionList = SqlParserSelectHelper.getFilterExpression(correctorSql);
for (QueryFilter dslQueryFilter : queryData.getDimensionFilters()) {
for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) {
if (dslQueryFilter.getBizName().equals(queryFilter.getBizName())) {
Map<String, String> map = new HashMap<>();
map.put(queryFilter.getValue().toString(), dslQueryFilter.getValue().toString());
filedNameToValueMap.put(dslQueryFilter.getBizName(), map);
Map<String, String> map = new HashMap<>();
for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName().equals(dslQueryFilter.getBizName())
&& dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString());
break;
}
}
filedNameToValueMap.put(dslQueryFilter.getBizName(), map);
}
for (QueryFilter dslQueryFilter : queryData.getMetricFilters()) {
Map<String, String> map = new HashMap<>();
for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName().equals(dslQueryFilter.getBizName())
&& dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString());
break;
}
}
filedNameToValueMap.put(dslQueryFilter.getBizName(), map);
}
String dateField = "sys_imp_date";
if (Objects.nonNull(queryData.getDateInfo())) {
Map<String, String> map = new HashMap<>();
List<String> dateFields = Lists.newArrayList("dayno", "sys_imp_date", "sys_imp_week", "sys_imp_month");
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
for (FilterExpression filterExpression : filterExpressionList) {
if (dateFields.contains(filterExpression.getFieldName())) {
dateField = filterExpression.getFieldName();
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
break;
}
}
} else {
for (FilterExpression filterExpression : filterExpressionList) {
if (dateFields.contains(filterExpression.getFieldName())) {
dateField = filterExpression.getFieldName();
if (filterExpression.getOperator().equals(">=")
|| filterExpression.getOperator().equals(">")) {
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
}
if (filterExpression.getOperator().equals("<=")
|| filterExpression.getOperator().equals("<")) {
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getEndDate());
}
}
}
}
filedNameToValueMap.put(dateField, map);
}
log.info("filedNameToValueMap:{}", filedNameToValueMap);
correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap);

View File

@@ -2,10 +2,19 @@ package com.tencent.supersonic.common.util.jsqlparser;
import java.util.Map;
import java.util.Objects;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
@@ -51,4 +60,64 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
}
}
}
public void visit(GreaterThan expr) {
replaceComparisonExpression(expr);
}
public void visit(GreaterThanEquals expr) {
replaceComparisonExpression(expr);
}
public void visit(MinorThanEquals expr) {
replaceComparisonExpression(expr);
}
public void visit(MinorThan expr) {
replaceComparisonExpression(expr);
}
public <T extends Expression> void replaceComparisonExpression(T expression) {
if ((expression instanceof GreaterThanEquals) || (expression instanceof GreaterThan)
|| (expression instanceof MinorThanEquals) || (expression instanceof MinorThan)) {
Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression();
Expression rightExpression = ((ComparisonOperator) expression).getRightExpression();
if (!(leftExpression instanceof Column)) {
return;
}
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
return;
}
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
return;
}
Column leftColumnName = (Column) leftExpression;
String columnName = leftColumnName.getColumnName();
if (StringUtils.isEmpty(columnName)) {
return;
}
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
return;
}
for (String oriValue : valueMap.keySet()) {
String replaceValue = valueMap.get(oriValue);
if (StringUtils.isNotEmpty(replaceValue)) {
if (rightExpression instanceof LongValue) {
LongValue rightStringValue = (LongValue) rightExpression;
rightStringValue.setValue(Long.parseLong(replaceValue));
}
if (rightExpression instanceof DoubleValue) {
DoubleValue rightStringValue = (DoubleValue) rightExpression;
rightStringValue.setValue(Double.parseDouble(replaceValue));
}
if (rightExpression instanceof StringValue) {
StringValue rightStringValue = (StringValue) rightExpression;
rightStringValue.setValue(replaceValue);
}
}
}
}
}
}

View File

@@ -1,3 +1,40 @@
tagore _3_8 9000
nazrul _3_8 9000
民间 _3_8 9000
现代 _3_8 9000
蓝调 _3_8 9000
流行 _3_8 9000
孟加拉国 _3_10 9000
锡尔赫特、吉大港、库斯蒂亚 _3_10 9000
加拿大 _3_10 9000
美国 _3_10 9000
Shrikanta _3_11 9000
Prity _3_11 9000
Farida _3_11 9000
Topu _3_11 9000
Enrique _3_11 9000
Michel _3_11 9000
孟加拉国 _3_12 9000
印度 _3_12 9000
美国 _3_12 9000
英国 _3_12 9000
男性 _3_13 9000
女性 _3_13 9000
mp4 _3_19 9000
mp3 _3_19 9000
Tumi#长袍#尼罗布 _3_20 9000
舒克诺#帕塔尔#努普尔#帕埃 _3_20 9000
阿米·奥帕尔·霍伊 _3_20 9000
我的爱 _3_20 9000
打败它 _3_20 9000
阿杰伊阿卡什 _3_20 9000
孟加拉国 _3_22 9000
印度 _3_22 9000
美国 _3_22 9000
英国 _3_22 9000
孟加拉语 _3_26 9000
英文 _3_26 9000
=======
孟加拉国 _3_8 9000
锡尔赫特、吉大港、库斯蒂亚 _3_8 9000
加拿大 _3_8 9000

View File

@@ -59,6 +59,12 @@ public class QueryController {
return queryService.queryByStructWithAuth(queryStructReq, user);
}
@PostMapping("/queryStatement")
public Object queryStatement(@RequestBody QueryStatement queryStatement) throws Exception {
Object result = queryService.queryByQueryStatement(queryStatement);
return result;
}
@PostMapping("/struct/parse")
public SqlParserResp parseByStruct(@RequestBody ParseSqlReq parseSqlReq,
HttpServletRequest request,

View File

@@ -155,9 +155,11 @@ public class AuthCommonService {
boolean doDesensitization = false;
for (QueryColumn queryColumn : columns) {
if (need2Apply.contains(queryColumn.getNameEn())) {
doDesensitization = true;
break;
for (String sensitiveCol : need2Apply) {
if (queryColumn.getNameEn().contains(sensitiveCol)) {
doDesensitization = true;
break;
}
}
}
if (!doDesensitization) {
@@ -192,8 +194,15 @@ public class AuthCommonService {
Map<String, Object> row = result.get(i);
Map<String, Object> newRow = new HashMap<>();
for (String col : row.keySet()) {
if (need2Apply.contains(col)) {
newRow.put(col, "****");
boolean sensitive = false;
for (String sensitiveCol : need2Apply) {
if (col.contains(sensitiveCol)) {
sensitive = true;
break;
}
}
if (sensitive) {
newRow.put(col, "******");
} else {
newRow.put(col, row.get(col));
}

View File

@@ -10,6 +10,8 @@ import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.api.query.response.ItemUseResp;
import com.tencent.supersonic.semantic.query.persistence.pojo.QueryStatement;
import java.util.List;
public interface QueryService {
@@ -25,6 +27,8 @@ public interface QueryService {
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
Object queryByQueryStatement(QueryStatement queryStatement);
List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend);
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;

View File

@@ -77,11 +77,17 @@ public class QueryServiceImpl implements QueryService {
} catch (Exception e) {
log.info("convertToQueryStatement has a exception:{}", e.toString());
}
log.info("queryStatement:{}", queryStatement);
QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement);
statUtils.statInfo2DbAsync(TaskStatusEnum.SUCCESS);
return results;
}
public Object queryByQueryStatement(QueryStatement queryStatement) {
QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement);
return results;
}
private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception {
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
List<Long> modelIds = new ArrayList<>();

View File

@@ -84,7 +84,7 @@ public class DslDataAspect {
authCommonService.doModelVisible(user, modelId);
// 3. fetch data permission meta information
Set<String> res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryDslReq);
Set<String> res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryDslReq, user);
log.info("modelId:{}, res4Privilege:{}", modelId, res4Privilege);
Set<String> sensitiveResByModel = authCommonService.getHighSensitiveColsByModelId(modelId);
@@ -117,6 +117,7 @@ public class DslDataAspect {
// 6.if the column has no permission, hit *
Set<String> need2Apply = sensitiveResReq.stream().filter(req -> !resAuthSet.contains(req))
.collect(Collectors.toSet());
log.info("need2Apply:{},sensitiveResReq:{},resAuthSet:{}", need2Apply, sensitiveResReq, resAuthSet);
QueryResultWithSchemaResp queryResultAfterDesensitization = authCommonService
.desensitizationData(queryResultWithColumns, need2Apply);
authCommonService.addPromptInfoInfo(modelId, queryResultAfterDesensitization, authorizedResource, need2Apply);

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.semantic.query.utils;
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.Aggregator;
@@ -9,9 +10,13 @@ import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter;
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ItemDateResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.model.domain.Catalog;
@@ -25,8 +30,11 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.query.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.assertj.core.util.Lists;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -42,6 +50,8 @@ public class QueryStructUtils {
private final Catalog catalog;
@Value("${internal.metric.cnt.suffix:internal_cnt}")
private String internalMetricNameSuffix;
@Autowired
private SchemaService schemaService;
public QueryStructUtils(
DateUtils dateUtils,
@@ -147,17 +157,37 @@ public class QueryStructUtils {
sqlFilterUtils.getFiltersCol(queryStructCmd.getOriginalFilter()).stream().forEach(col -> resNameEnSet.add(col));
return resNameEnSet;
}
public Set<String> getResNameEn(QueryDslReq queryDslReq) {
Set<String> resNameEnSet = SqlParserSelectHelper.getAllFields(queryDslReq.getSql())
public Set<String> getResName(QueryDslReq queryDslReq) {
Set<String> resNameSet = SqlParserSelectHelper.getAllFields(queryDslReq.getSql())
.stream().collect(Collectors.toSet());
return resNameEnSet;
return resNameSet;
}
public Set<String> getResNameEnExceptInternalCol(QueryStructReq queryStructCmd) {
Set<String> resNameEnSet = getResNameEn(queryStructCmd);
return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet());
}
public Set<String> getResNameEnExceptInternalCol(QueryDslReq queryDslReq) {
Set<String> resNameEnSet = getResNameEn(queryDslReq);
public Set<String> getResNameEnExceptInternalCol(QueryDslReq queryDslReq, User user) {
Set<String> resNameSet = getResName(queryDslReq);
Set<String> resNameEnSet = new HashSet<>();
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
List<Long> modelIds = Lists.newArrayList(queryDslReq.getModelId());
filter.setModelIds(modelIds);
List<ModelSchemaResp> modelSchemaRespList = schemaService.fetchModelSchema(filter, user);
if (!CollectionUtils.isEmpty(modelSchemaRespList)) {
List<MetricSchemaResp> metrics = modelSchemaRespList.get(0).getMetrics();
List<DimSchemaResp> dimensions = modelSchemaRespList.get(0).getDimensions();
metrics.stream().forEach(o -> {
if (resNameSet.contains(o.getName())) {
resNameEnSet.add(o.getBizName());
}
});
dimensions.stream().forEach(o -> {
if (resNameSet.contains(o.getName())) {
resNameEnSet.add(o.getBizName());
}
});
}
return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet());
}