diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java index af9a1425e..ed1395a83 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java @@ -14,6 +14,7 @@ public class QueryDataReq { private Set metrics = new HashSet<>(); private Set dimensions = new HashSet<>(); private Set dimensionFilters = new HashSet<>(); + private Set metricFilters = new HashSet<>(); private DateConf dateInfo; private Long queryId = 7L; private Integer parseId = 2; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index b4d1a683b..975aeedd8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.service.impl; +import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SchemaMapper; import com.tencent.supersonic.chat.api.component.SemanticLayer; @@ -38,6 +39,8 @@ import com.tencent.supersonic.chat.utils.ComponentFactory; import java.util.Map; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; +import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import java.util.List; import java.util.ArrayList; import java.util.Set; @@ -295,27 +298,73 @@ public class QueryServiceImpl implements QueryService { if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { parseInfo.setDimensionFilters(queryData.getDimensionFilters()); } - } - if (Objects.nonNull(queryData.getDateInfo())) { - parseInfo.setDateInfo(queryData.getDateInfo()); + if (Objects.nonNull(queryData.getDateInfo())) { + parseInfo.setDateInfo(queryData.getDateInfo()); + } } if (parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE) - && CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { + && (CollectionUtils.isNotEmpty(queryData.getDimensionFilters()) + || CollectionUtils.isNotEmpty(queryData.getMetricFilters()))) { Map> filedNameToValueMap = new HashMap<>(); String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class); LLMResp llmResp = dslParseResult.getLlmResp(); String correctorSql = llmResp.getCorrectorSql(); log.info("correctorSql before replacing:{}", correctorSql); + List filterExpressionList = SqlParserSelectHelper.getFilterExpression(correctorSql); for (QueryFilter dslQueryFilter : queryData.getDimensionFilters()) { - for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) { - if (dslQueryFilter.getBizName().equals(queryFilter.getBizName())) { - Map map = new HashMap<>(); - map.put(queryFilter.getValue().toString(), dslQueryFilter.getValue().toString()); - filedNameToValueMap.put(dslQueryFilter.getBizName(), map); + Map map = new HashMap<>(); + for (FilterExpression filterExpression : filterExpressionList) { + if (filterExpression.getFieldName().equals(dslQueryFilter.getBizName()) + && dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) { + map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString()); break; } } + filedNameToValueMap.put(dslQueryFilter.getBizName(), map); + } + for (QueryFilter dslQueryFilter : queryData.getMetricFilters()) { + Map map = new HashMap<>(); + for (FilterExpression filterExpression : filterExpressionList) { + if (filterExpression.getFieldName().equals(dslQueryFilter.getBizName()) + && dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) { + map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString()); + break; + } + } + filedNameToValueMap.put(dslQueryFilter.getBizName(), map); + } + String dateField = "sys_imp_date"; + if (Objects.nonNull(queryData.getDateInfo())) { + Map map = new HashMap<>(); + List dateFields = Lists.newArrayList("dayno", "sys_imp_date", "sys_imp_week", "sys_imp_month"); + if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) { + for (FilterExpression filterExpression : filterExpressionList) { + if (dateFields.contains(filterExpression.getFieldName())) { + dateField = filterExpression.getFieldName(); + map.put(filterExpression.getFieldValue().toString(), + queryData.getDateInfo().getStartDate()); + break; + } + } + } else { + for (FilterExpression filterExpression : filterExpressionList) { + if (dateFields.contains(filterExpression.getFieldName())) { + dateField = filterExpression.getFieldName(); + if (filterExpression.getOperator().equals(">=") + || filterExpression.getOperator().equals(">")) { + map.put(filterExpression.getFieldValue().toString(), + queryData.getDateInfo().getStartDate()); + } + if (filterExpression.getOperator().equals("<=") + || filterExpression.getOperator().equals("<")) { + map.put(filterExpression.getFieldValue().toString(), + queryData.getDateInfo().getEndDate()); + } + } + } + } + filedNameToValueMap.put(dateField, map); } log.info("filedNameToValueMap:{}", filedNameToValueMap); correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java index 915629992..f78f717fe 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java @@ -2,10 +2,19 @@ package com.tencent.supersonic.common.util.jsqlparser; import java.util.Map; import java.util.Objects; + import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.StringValue; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.DoubleValue; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; + import net.sf.jsqlparser.schema.Column; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -51,4 +60,64 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { } } -} \ No newline at end of file + public void visit(GreaterThan expr) { + replaceComparisonExpression(expr); + } + + public void visit(GreaterThanEquals expr) { + replaceComparisonExpression(expr); + } + + public void visit(MinorThanEquals expr) { + replaceComparisonExpression(expr); + } + + public void visit(MinorThan expr) { + replaceComparisonExpression(expr); + } + + public void replaceComparisonExpression(T expression) { + if ((expression instanceof GreaterThanEquals) || (expression instanceof GreaterThan) + || (expression instanceof MinorThanEquals) || (expression instanceof MinorThan)) { + Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression(); + Expression rightExpression = ((ComparisonOperator) expression).getRightExpression(); + if (!(leftExpression instanceof Column)) { + return; + } + if (CollectionUtils.isEmpty(filedNameToValueMap)) { + return; + } + if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { + return; + } + Column leftColumnName = (Column) leftExpression; + + String columnName = leftColumnName.getColumnName(); + if (StringUtils.isEmpty(columnName)) { + return; + } + Map valueMap = filedNameToValueMap.get(columnName); + if (Objects.isNull(valueMap) || valueMap.isEmpty()) { + return; + } + for (String oriValue : valueMap.keySet()) { + String replaceValue = valueMap.get(oriValue); + if (StringUtils.isNotEmpty(replaceValue)) { + if (rightExpression instanceof LongValue) { + LongValue rightStringValue = (LongValue) rightExpression; + rightStringValue.setValue(Long.parseLong(replaceValue)); + } + if (rightExpression instanceof DoubleValue) { + DoubleValue rightStringValue = (DoubleValue) rightExpression; + rightStringValue.setValue(Double.parseDouble(replaceValue)); + } + if (rightExpression instanceof StringValue) { + StringValue rightStringValue = (StringValue) rightExpression; + rightStringValue.setValue(replaceValue); + } + } + } + + } + } +} diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/benchmark_cspider.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/benchmark_cspider.txt index 2c86bfce4..6a647af87 100644 --- a/launchers/standalone/src/main/resources/data/dictionary/custom/benchmark_cspider.txt +++ b/launchers/standalone/src/main/resources/data/dictionary/custom/benchmark_cspider.txt @@ -1,3 +1,40 @@ +tagore _3_8 9000 +nazrul _3_8 9000 +民间 _3_8 9000 +现代 _3_8 9000 +蓝调 _3_8 9000 +流行 _3_8 9000 +孟加拉国 _3_10 9000 +锡尔赫特、吉大港、库斯蒂亚 _3_10 9000 +加拿大 _3_10 9000 +美国 _3_10 9000 +Shrikanta _3_11 9000 +Prity _3_11 9000 +Farida _3_11 9000 +Topu _3_11 9000 +Enrique _3_11 9000 +Michel _3_11 9000 +孟加拉国 _3_12 9000 +印度 _3_12 9000 +美国 _3_12 9000 +英国 _3_12 9000 +男性 _3_13 9000 +女性 _3_13 9000 +mp4 _3_19 9000 +mp3 _3_19 9000 +Tumi#长袍#尼罗布 _3_20 9000 +舒克诺#帕塔尔#努普尔#帕埃 _3_20 9000 +阿米·奥帕尔·霍伊 _3_20 9000 +我的爱 _3_20 9000 +打败它 _3_20 9000 +阿杰伊阿卡什 _3_20 9000 +孟加拉国 _3_22 9000 +印度 _3_22 9000 +美国 _3_22 9000 +英国 _3_22 9000 +孟加拉语 _3_26 9000 +英文 _3_26 9000 +======= 孟加拉国 _3_8 9000 锡尔赫特、吉大港、库斯蒂亚 _3_8 9000 加拿大 _3_8 9000 diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java index eafe6c0d8..b7f215c56 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java @@ -59,6 +59,12 @@ public class QueryController { return queryService.queryByStructWithAuth(queryStructReq, user); } + @PostMapping("/queryStatement") + public Object queryStatement(@RequestBody QueryStatement queryStatement) throws Exception { + Object result = queryService.queryByQueryStatement(queryStatement); + return result; + } + @PostMapping("/struct/parse") public SqlParserResp parseByStruct(@RequestBody ParseSqlReq parseSqlReq, HttpServletRequest request, diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/AuthCommonService.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/AuthCommonService.java index e6a30a620..1ba1420f4 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/AuthCommonService.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/AuthCommonService.java @@ -155,9 +155,11 @@ public class AuthCommonService { boolean doDesensitization = false; for (QueryColumn queryColumn : columns) { - if (need2Apply.contains(queryColumn.getNameEn())) { - doDesensitization = true; - break; + for (String sensitiveCol : need2Apply) { + if (queryColumn.getNameEn().contains(sensitiveCol)) { + doDesensitization = true; + break; + } } } if (!doDesensitization) { @@ -192,8 +194,15 @@ public class AuthCommonService { Map row = result.get(i); Map newRow = new HashMap<>(); for (String col : row.keySet()) { - if (need2Apply.contains(col)) { - newRow.put(col, "****"); + boolean sensitive = false; + for (String sensitiveCol : need2Apply) { + if (col.contains(sensitiveCol)) { + sensitive = true; + break; + } + } + if (sensitive) { + newRow.put(col, "******"); } else { newRow.put(col, row.get(col)); } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java index 8e8a6e4c8..a7ecd6d73 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java @@ -10,6 +10,8 @@ import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.response.ItemUseResp; +import com.tencent.supersonic.semantic.query.persistence.pojo.QueryStatement; + import java.util.List; public interface QueryService { @@ -25,6 +27,8 @@ public interface QueryService { QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user); + Object queryByQueryStatement(QueryStatement queryStatement); + List getStatInfo(ItemUseReq itemUseCommend); ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception; diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java index ea6c675a9..f7061a188 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java @@ -77,11 +77,17 @@ public class QueryServiceImpl implements QueryService { } catch (Exception e) { log.info("convertToQueryStatement has a exception:{}", e.toString()); } + log.info("queryStatement:{}", queryStatement); QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement); statUtils.statInfo2DbAsync(TaskStatusEnum.SUCCESS); return results; } + public Object queryByQueryStatement(QueryStatement queryStatement) { + QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement); + return results; + } + private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception { ModelSchemaFilterReq filter = new ModelSchemaFilterReq(); List modelIds = new ArrayList<>(); diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java index 15d706c6f..0fb93e8ca 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java @@ -84,7 +84,7 @@ public class DslDataAspect { authCommonService.doModelVisible(user, modelId); // 3. fetch data permission meta information - Set res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryDslReq); + Set res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryDslReq, user); log.info("modelId:{}, res4Privilege:{}", modelId, res4Privilege); Set sensitiveResByModel = authCommonService.getHighSensitiveColsByModelId(modelId); @@ -117,6 +117,7 @@ public class DslDataAspect { // 6.if the column has no permission, hit * Set need2Apply = sensitiveResReq.stream().filter(req -> !resAuthSet.contains(req)) .collect(Collectors.toSet()); + log.info("need2Apply:{},sensitiveResReq:{},resAuthSet:{}", need2Apply, sensitiveResReq, resAuthSet); QueryResultWithSchemaResp queryResultAfterDesensitization = authCommonService .desensitizationData(queryResultWithColumns, need2Apply); authCommonService.addPromptInfoInfo(modelId, queryResultAfterDesensitization, authorizedResource, need2Apply); diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java index 65add3e75..94c6c06a4 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.semantic.query.utils; import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.Aggregator; @@ -9,9 +10,13 @@ import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter; +import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ItemDateResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; +import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp; +import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.model.domain.Catalog; @@ -25,8 +30,11 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import com.tencent.supersonic.semantic.query.service.SchemaService; import lombok.extern.slf4j.Slf4j; import org.apache.logging.log4j.util.Strings; +import org.assertj.core.util.Lists; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -42,6 +50,8 @@ public class QueryStructUtils { private final Catalog catalog; @Value("${internal.metric.cnt.suffix:internal_cnt}") private String internalMetricNameSuffix; + @Autowired + private SchemaService schemaService; public QueryStructUtils( DateUtils dateUtils, @@ -147,17 +157,37 @@ public class QueryStructUtils { sqlFilterUtils.getFiltersCol(queryStructCmd.getOriginalFilter()).stream().forEach(col -> resNameEnSet.add(col)); return resNameEnSet; } - public Set getResNameEn(QueryDslReq queryDslReq) { - Set resNameEnSet = SqlParserSelectHelper.getAllFields(queryDslReq.getSql()) + public Set getResName(QueryDslReq queryDslReq) { + Set resNameSet = SqlParserSelectHelper.getAllFields(queryDslReq.getSql()) .stream().collect(Collectors.toSet()); - return resNameEnSet; + return resNameSet; } public Set getResNameEnExceptInternalCol(QueryStructReq queryStructCmd) { Set resNameEnSet = getResNameEn(queryStructCmd); return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet()); } - public Set getResNameEnExceptInternalCol(QueryDslReq queryDslReq) { - Set resNameEnSet = getResNameEn(queryDslReq); + + public Set getResNameEnExceptInternalCol(QueryDslReq queryDslReq, User user) { + Set resNameSet = getResName(queryDslReq); + Set resNameEnSet = new HashSet<>(); + ModelSchemaFilterReq filter = new ModelSchemaFilterReq(); + List modelIds = Lists.newArrayList(queryDslReq.getModelId()); + filter.setModelIds(modelIds); + List modelSchemaRespList = schemaService.fetchModelSchema(filter, user); + if (!CollectionUtils.isEmpty(modelSchemaRespList)) { + List metrics = modelSchemaRespList.get(0).getMetrics(); + List dimensions = modelSchemaRespList.get(0).getDimensions(); + metrics.stream().forEach(o -> { + if (resNameSet.contains(o.getName())) { + resNameEnSet.add(o.getBizName()); + } + }); + dimensions.stream().forEach(o -> { + if (resNameSet.contains(o.getName())) { + resNameEnSet.add(o.getBizName()); + } + }); + } return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet()); }