diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/knowledge/semantic/LocalSemanticInterpreter.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/knowledge/semantic/LocalSemanticInterpreter.java index 41c1f91b8..fbb7e54e9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/knowledge/semantic/LocalSemanticInterpreter.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/knowledge/semantic/LocalSemanticInterpreter.java @@ -49,14 +49,14 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter { return queryByS2SQL(querySQLReq, user); } queryService = ContextUtils.getBean(QueryService.class); - return queryService.queryByStructWithAuth(queryStructReq, user); + return queryService.queryByReq(queryStructReq, user); } @Override public SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) { try { queryService = ContextUtils.getBean(QueryService.class); - return queryService.queryByMultiStruct(queryMultiStructReq, user); + return queryService.queryByReq(queryMultiStructReq, user); } catch (Exception e) { log.info("queryByMultiStruct has an exception:{}", e); } @@ -67,7 +67,7 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter { @SneakyThrows public SemanticQueryResp queryByS2SQL(QuerySqlReq querySQLReq, User user) { queryService = ContextUtils.getBean(QueryService.class); - SemanticQueryResp object = queryService.queryBySql(querySQLReq, user); + SemanticQueryResp object = queryService.queryByReq(querySQLReq, user); return JsonUtil.toObject(JsonUtil.toString(object), SemanticQueryResp.class); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/utils/DictQueryHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/utils/DictQueryHelper.java index cf4a7ccce..56b2ecc15 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/utils/DictQueryHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/utils/DictQueryHelper.java @@ -181,6 +181,7 @@ public class DictQueryHelper { queryStructCmd.setDateInfo(dateInfo); queryStructCmd.setLimit(dimMaxLimit); + queryStructCmd.setNeedAuth(false); return queryStructCmd; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java index 1624bbbb3..5b310ec39 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java @@ -17,6 +17,8 @@ import org.apache.commons.codec.digest.DigestUtils; @Slf4j public abstract class SemanticQueryReq { + protected boolean needAuth = true; + protected Set modelIds; protected List params = new ArrayList<>(); @@ -45,4 +47,11 @@ public abstract class SemanticQueryReq { return modelIds; } + public boolean isNeedAuth() { + return needAuth; + } + + public void setNeedAuth(boolean needAuth) { + this.needAuth = needAuth; + } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java new file mode 100644 index 000000000..8d6884e47 --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java @@ -0,0 +1,67 @@ +package com.tencent.supersonic.headless.core.cache; + + +import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +@Component +@Slf4j +public class DefaultQueryCache implements QueryCache { + + @Value("${query.cache.enable:true}") + private Boolean cacheEnable; + @Autowired + private CacheManager cacheManager; + + public Object query(SemanticQueryReq semanticQueryReq) { + String cacheKey = getCacheKey(semanticQueryReq); + if (isCache(semanticQueryReq)) { + Object result = cacheManager.get(cacheKey); + log.info("queryFromCache, key:{}, semanticQueryReq:{}", cacheKey, semanticQueryReq); + return result; + } + return null; + } + + public Boolean put(SemanticQueryReq semanticQueryReq, Object value) { + if (cacheEnable && Objects.nonNull(value)) { + String key = getCacheKey(semanticQueryReq); + CompletableFuture.supplyAsync(() -> cacheManager.put(key, value)) + .exceptionally(exception -> { + log.warn("exception:", exception); + return null; + }); + log.info("add record to cache, key:{}", key); + return true; + } + return false; + } + + public String getCacheKey(SemanticQueryReq semanticQueryReq) { + String commandMd5 = semanticQueryReq.generateCommandMd5(); + String keyByModelIds = getKeyByModelIds(semanticQueryReq.getModelIds()); + return cacheManager.generateCacheKey(keyByModelIds, commandMd5); + } + + private String getKeyByModelIds(List modelIds) { + return String.join(",", modelIds.stream().map(Object::toString).collect(Collectors.toList())); + } + + private boolean isCache(SemanticQueryReq semanticQueryReq) { + if (!cacheEnable) { + return false; + } + if (semanticQueryReq.getCacheInfo() != null) { + return semanticQueryReq.getCacheInfo().getCache(); + } + return false; + } + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/QueryCache.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/QueryCache.java index b964498df..9eb2f3b05 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/QueryCache.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/QueryCache.java @@ -1,78 +1,14 @@ package com.tencent.supersonic.headless.core.cache; -import com.tencent.supersonic.headless.api.pojo.Cache; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import java.util.List; -import java.util.Objects; -import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; -@Component -@Slf4j -public class QueryCache { +public interface QueryCache { - @Value("${query.cache.enable:true}") - private Boolean cacheEnable; - @Autowired - private CacheManager cacheManager; + Object query(SemanticQueryReq semanticQueryReq); - public Object query(SemanticQueryReq semanticQueryReq) { - String cacheKey = getCacheKey(semanticQueryReq); - handleGlobalCacheDisable(semanticQueryReq); - boolean isCache = isCache(semanticQueryReq); - if (isCache) { - Object result = cacheManager.get(cacheKey); - log.info("queryFromCache, key:{}, semanticQueryReq:{}", cacheKey, semanticQueryReq); - return result; - } - return null; - } + Boolean put(SemanticQueryReq semanticQueryReq, Object value); - public Boolean put(SemanticQueryReq semanticQueryReq, Object value) { - if (cacheEnable && Objects.nonNull(value)) { - String key = getCacheKey(semanticQueryReq); - CompletableFuture.supplyAsync(() -> cacheManager.put(key, value)) - .exceptionally(exception -> { - log.warn("exception:", exception); - return null; - }); - log.info("add record to cache, key:{}", key); - return true; - } - return false; - } - - public String getCacheKey(SemanticQueryReq semanticQueryReq) { - String commandMd5 = semanticQueryReq.generateCommandMd5(); - String keyByModelIds = getKeyByModelIds(semanticQueryReq.getModelIds()); - return cacheManager.generateCacheKey(keyByModelIds, commandMd5); - } - - private void handleGlobalCacheDisable(SemanticQueryReq semanticQueryReq) { - if (!cacheEnable) { - Cache cacheInfo = new Cache(); - cacheInfo.setCache(false); - semanticQueryReq.setCacheInfo(cacheInfo); - } - } - - private String getKeyByModelIds(List modelIds) { - return String.join(",", modelIds.stream().map(Object::toString).collect(Collectors.toList())); - } - - private boolean isCache(SemanticQueryReq semanticQueryReq) { - if (!cacheEnable) { - return false; - } - if (semanticQueryReq.getCacheInfo() != null) { - return semanticQueryReq.getCacheInfo().getCache(); - } - return false; - } + String getCacheKey(SemanticQueryReq semanticQueryReq); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2SQLDataPermission.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2DataPermission.java similarity index 89% rename from headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2SQLDataPermission.java rename to headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2DataPermission.java index 81df3dcb1..3eaba07f8 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2SQLDataPermission.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2DataPermission.java @@ -9,6 +9,6 @@ import java.lang.annotation.Documented; @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @Documented -public @interface S2SQLDataPermission { +public @interface S2DataPermission { } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/StructDataPermission.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/StructDataPermission.java deleted file mode 100644 index 9744b8d99..000000000 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/StructDataPermission.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.tencent.supersonic.headless.server.annotation; - -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -@Target({ElementType.PARAMETER, ElementType.METHOD}) -@Retention(RetentionPolicy.RUNTIME) -public @interface StructDataPermission { - -} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java index 0e99d1b54..c84e77433 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java @@ -4,17 +4,26 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.headless.api.pojo.DimValueMap; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; -import com.tencent.supersonic.headless.api.pojo.DimValueMap; +import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.service.DimensionService; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; @@ -26,14 +35,6 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - @Aspect @Component @Slf4j @@ -41,51 +42,79 @@ public class DimValueAspect { @Value("${dimension.value.map.enable:true}") private Boolean dimensionValueMapEnable; - - @Value("${dimension.value.map.sql.enable:true}") - private Boolean dimensionValueMapSqlEnable; @Autowired private DimensionService dimensionService; - @Around("execution(* com.tencent.supersonic.headless.server.service.impl.QueryServiceImpl.queryBySql(..))") - public Object handleSqlDimValue(ProceedingJoinPoint joinPoint) throws Throwable { - if (!dimensionValueMapSqlEnable) { - log.debug("sql dimensionValueMapEnable is false, skip dimensionValueMap"); + @Around("execution(* com.tencent.supersonic.headless.server.service.QueryService.queryByReq(..))") + public Object handleDimValue(ProceedingJoinPoint joinPoint) throws Throwable { + if (!dimensionValueMapEnable) { + log.debug("dimensionValueMapEnable is false, skip dimensionValueMap"); SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) joinPoint.proceed(); return queryResultWithColumns; } + + Object[] args = joinPoint.getArgs(); + SemanticQueryReq queryReq = (SemanticQueryReq) args[0]; + if (queryReq instanceof QueryStructReq) { + return handleStructDimValue(joinPoint); + } + if (queryReq instanceof QuerySqlReq) { + return handleSqlDimValue(joinPoint); + } + throw new InvalidArgumentException("queryReq is not Invalid:" + queryReq); + } + + private SemanticQueryResp handleStructDimValue(ProceedingJoinPoint joinPoint) throws Throwable { + Object[] args = joinPoint.getArgs(); + QueryStructReq queryStructReq = (QueryStructReq) args[0]; + MetaFilter metaFilter = new MetaFilter(Lists.newArrayList(queryStructReq.getModelIds())); + List dimensions = dimensionService.getDimensions(metaFilter); + Map> dimAndAliasAndTechNamePair = getAliasAndBizNameToTechName(dimensions); + Map> dimAndTechNameAndBizNamePair = getTechNameToBizName(dimensions); + + rewriteFilter(queryStructReq.getDimensionFilters(), dimAndAliasAndTechNamePair); + + SemanticQueryResp semanticQueryResp = (SemanticQueryResp) joinPoint.proceed(); + if (Objects.nonNull(semanticQueryResp)) { + rewriteDimValue(semanticQueryResp, dimAndTechNameAndBizNamePair); + } + + return semanticQueryResp; + } + + public Object handleSqlDimValue(ProceedingJoinPoint joinPoint) throws Throwable { Object[] args = joinPoint.getArgs(); QuerySqlReq querySQLReq = (QuerySqlReq) args[0]; MetaFilter metaFilter = new MetaFilter(Lists.newArrayList(querySQLReq.getModelIds())); String sql = querySQLReq.getSql(); log.info("correctorSql before replacing:{}", sql); - // if dimensionvalue is alias,consider the true dimensionvalue. List fieldExpressionList = SqlParserSelectHelper.getWhereExpressions(sql); List dimensions = dimensionService.getDimensions(metaFilter); Set fieldNames = dimensions.stream().map(o -> o.getName()).collect(Collectors.toSet()); Map> filedNameToValueMap = new HashMap<>(); - fieldExpressionList.stream().forEach(expression -> { - if (fieldNames.contains(expression.getFieldName())) { - dimensions.stream().forEach(dimension -> { - if (expression.getFieldName().equals(dimension.getName()) - && !CollectionUtils.isEmpty(dimension.getDimValueMaps())) { - // consider '=' filter - if (expression.getOperator().equals(FilterOperatorEnum.EQUALS.getValue())) { - dimension.getDimValueMaps().stream().forEach(dimValue -> { - if (!CollectionUtils.isEmpty(dimValue.getAlias()) - && dimValue.getAlias().contains(expression.getFieldValue().toString())) { - getFiledNameToValueMap(filedNameToValueMap, expression.getFieldValue().toString(), - dimValue.getTechName(), expression.getFieldName()); - } - }); - } - // consider 'in' filter,each element needs to judge. - replaceInCondition(expression, dimension, filedNameToValueMap); - } - }); + for (FieldExpression expression : fieldExpressionList) { + if (!fieldNames.contains(expression.getFieldName())) { + continue; } - }); - log.info("filedNameToValueMap:{}", filedNameToValueMap); + for (DimensionResp dimension : dimensions) { + if (!expression.getFieldName().equals(dimension.getName()) + || CollectionUtils.isEmpty(dimension.getDimValueMaps())) { + continue; + } + // consider '=' filter + if (expression.getOperator().equals(FilterOperatorEnum.EQUALS.getValue())) { + dimension.getDimValueMaps().stream().forEach(dimValue -> { + if (!CollectionUtils.isEmpty(dimValue.getAlias()) + && dimValue.getAlias().contains(expression.getFieldValue().toString())) { + getFiledNameToValueMap(filedNameToValueMap, expression.getFieldValue().toString(), + dimValue.getTechName(), expression.getFieldName()); + } + }); + } + // consider 'in' filter,each element needs to judge. + replaceInCondition(expression, dimension, filedNameToValueMap); + } + } sql = SqlParserReplaceHelper.replaceValue(sql, filedNameToValueMap); log.info("correctorSql after replacing:{}", sql); querySQLReq.setSql(sql); @@ -99,7 +128,7 @@ public class DimValueAspect { } public void replaceInCondition(FieldExpression expression, DimensionResp dimension, - Map> filedNameToValueMap) { + Map> filedNameToValueMap) { if (expression.getOperator().equals(FilterOperatorEnum.IN.getValue())) { String fieldValue = JsonUtil.toString(expression.getFieldValue()); fieldValue = fieldValue.replace("'", ""); @@ -127,40 +156,12 @@ public class DimValueAspect { } public void getFiledNameToValueMap(Map> filedNameToValueMap, - String oldValue, String newValue, String fieldName) { + String oldValue, String newValue, String fieldName) { Map map = new HashMap<>(); map.put(oldValue, newValue); filedNameToValueMap.put(fieldName, map); } - @Around("execution(* com.tencent.supersonic.headless.server.rest.QueryController.queryByStruct(..))" - + " || execution(* com.tencent.supersonic.headless.server.service.QueryService.queryByStruct(..))" - + " || execution(* com.tencent.supersonic.headless.server.service.QueryService.queryByStructWithAuth(..))") - public Object handleDimValue(ProceedingJoinPoint joinPoint) throws Throwable { - - if (!dimensionValueMapEnable) { - log.debug("dimensionValueMapEnable is false, skip dimensionValueMap"); - SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) joinPoint.proceed(); - return queryResultWithColumns; - } - - Object[] args = joinPoint.getArgs(); - QueryStructReq queryStructReq = (QueryStructReq) args[0]; - MetaFilter metaFilter = new MetaFilter(Lists.newArrayList(queryStructReq.getModelIds())); - List dimensions = dimensionService.getDimensions(metaFilter); - Map> dimAndAliasAndTechNamePair = getAliasAndBizNameToTechName(dimensions); - Map> dimAndTechNameAndBizNamePair = getTechNameToBizName(dimensions); - - rewriteFilter(queryStructReq.getDimensionFilters(), dimAndAliasAndTechNamePair); - - SemanticQueryResp semanticQueryResp = (SemanticQueryResp) joinPoint.proceed(); - if (Objects.nonNull(semanticQueryResp)) { - rewriteDimValue(semanticQueryResp, dimAndTechNameAndBizNamePair); - } - - return semanticQueryResp; - } - private void rewriteDimValue(SemanticQueryResp semanticQueryResp, Map> dimAndTechNameAndBizNamePair) { if (!selectDimValueMap(semanticQueryResp.getColumns(), dimAndTechNameAndBizNamePair)) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2SQLDataAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java similarity index 52% rename from headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2SQLDataAspect.java rename to headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java index a5f5a9ec3..3b4d8d5ce 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2SQLDataAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java @@ -7,9 +7,14 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.Filter; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; +import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; +import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; @@ -21,6 +26,7 @@ import com.tencent.supersonic.headless.server.utils.QueryStructUtils; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.StringJoiner; @@ -44,7 +50,7 @@ import org.springframework.util.CollectionUtils; @Aspect @Order(1) @Slf4j -public class S2SQLDataAspect extends AuthCheckBaseAspect { +public class S2DataPermissionAspect extends AuthCheckBaseAspect { @Autowired private QueryStructUtils queryStructUtils; @@ -55,33 +61,50 @@ public class S2SQLDataAspect extends AuthCheckBaseAspect { @Value("${permission.data.enable:true}") private Boolean permissionDataEnable; - @Pointcut("@annotation(com.tencent.supersonic.headless.server.annotation.S2SQLDataPermission)") - private void s2SQLPermissionCheck() { + @Pointcut("@annotation(com.tencent.supersonic.headless.server.annotation.S2DataPermission)") + private void s2PermissionCheck() { } - @Around("s2SQLPermissionCheck()") + @Around("s2PermissionCheck()") public Object doAround(ProceedingJoinPoint joinPoint) throws Throwable { - log.info("s2SQL permission check!"); - Object[] objects = joinPoint.getArgs(); - QuerySqlReq querySQLReq = (QuerySqlReq) objects[0]; - User user = (User) objects[1]; + log.info("s2 permission check!"); if (!permissionDataEnable) { - log.info("not to check s2SQL permission!"); + log.info("not to check permission!"); return joinPoint.proceed(); } + Object[] objects = joinPoint.getArgs(); + SemanticQueryReq queryReq = (SemanticQueryReq) objects[0]; + if (!queryReq.isNeedAuth()) { + log.info("needAuth is false, there is no need to check permissions."); + return joinPoint.proceed(); + } + User user = (User) objects[1]; if (Objects.isNull(user) || Strings.isNullOrEmpty(user.getName())) { throw new RuntimeException("please provide user information"); } - List modelIds = querySQLReq.getModelIds(); - //1. determine whether admin of the model - if (doModelAdmin(user, modelIds)) { - log.info("determine whether admin of the model!"); + // determine whether admin of the model + if (doModelAdmin(user, queryReq.getModelIds())) { return joinPoint.proceed(); } - // 2. determine whether the subject field is visible - doModelVisible(user, modelIds); - // 3. fetch data permission meta information + // determine whether the subject field is visible + doModelVisible(user, queryReq.getModelIds()); + + if (queryReq instanceof QuerySqlReq) { + return checkSqlPermission(joinPoint, (QuerySqlReq) queryReq); + } + if (queryReq instanceof QueryStructReq) { + return checkStructPermission(joinPoint, (QueryStructReq) queryReq); + } + throw new InvalidArgumentException("queryReq is not Invalid:" + queryReq); + } + + private Object checkSqlPermission(ProceedingJoinPoint joinPoint, QuerySqlReq querySQLReq) + throws Throwable { + Object[] objects = joinPoint.getArgs(); + User user = (User) objects[1]; + List modelIds = querySQLReq.getModelIds(); + // fetch data permission meta information Set res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(querySQLReq, user); log.info("modelId:{}, res4Privilege:{}", modelIds, res4Privilege); @@ -95,13 +118,13 @@ public class S2SQLDataAspect extends AuthCheckBaseAspect { // get sensitiveRes that user has privilege Set resAuthSet = getAuthResNameSet(authorizedResource, modelIds); - // 4.if sensitive fields without permission are involved in filter, thrown an exception + // if sensitive fields without permission are involved in filter, thrown an exception doFilterCheckLogic(querySQLReq, resAuthSet, sensitiveResReq); - // 5.row permission pre-filter + // row permission pre-filter doRowPermission(querySQLReq, authorizedResource); - // 6.proceed + // proceed SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) joinPoint.proceed(); if (CollectionUtils.isEmpty(sensitiveResReq) || allSensitiveResReqIsOk(sensitiveResReq, resAuthSet)) { @@ -110,7 +133,7 @@ public class S2SQLDataAspect extends AuthCheckBaseAspect { return getQueryResultWithColumns(queryResultWithColumns, modelIds, authorizedResource); } - // 6.if the column has no permission, hit * + // if the column has no permission, hit * Set need2Apply = sensitiveResReq.stream().filter(req -> !resAuthSet.contains(req)) .collect(Collectors.toSet()); log.info("need2Apply:{},sensitiveResReq:{},resAuthSet:{}", need2Apply, sensitiveResReq, resAuthSet); @@ -121,6 +144,111 @@ public class S2SQLDataAspect extends AuthCheckBaseAspect { return queryResultAfterDesensitization; } + private void doFilterCheckLogic(QuerySqlReq querySQLReq, Set resAuthName, + Set sensitiveResReq) { + Set resFilterSet = queryStructUtils.getFilterResNameEnExceptInternalCol(querySQLReq); + Set need2Apply = resFilterSet.stream() + .filter(res -> !resAuthName.contains(res) && sensitiveResReq.contains(res)).collect(Collectors.toSet()); + Set nameCnSet = new HashSet<>(); + + List modelIds = Lists.newArrayList(querySQLReq.getModelIds()); + ModelFilter modelFilter = new ModelFilter(); + modelFilter.setModelIds(modelIds); + List modelInfos = modelService.getModelList(modelFilter); + String modelNameCn = Constants.EMPTY; + if (!CollectionUtils.isEmpty(modelInfos)) { + modelNameCn = modelInfos.get(0).getName(); + } + MetaFilter metaFilter = new MetaFilter(modelIds); + List dimensionDescList = dimensionService.getDimensions(metaFilter); + String finalDomainNameCn = modelNameCn; + dimensionDescList.stream().filter(dim -> need2Apply.contains(dim.getBizName())) + .forEach(dim -> nameCnSet.add(finalDomainNameCn + MINUS + dim.getName())); + + if (!CollectionUtils.isEmpty(need2Apply)) { + ModelResp modelResp = modelInfos.get(0); + List admins = modelService.getModelAdmin(modelResp.getId()); + log.info("in doFilterLogic, need2Apply:{}", need2Apply); + String message = String.format("您没有以下维度%s权限, 请联系管理员%s开通", nameCnSet, admins); + throw new InvalidPermissionException(message); + } + } + + private void doFilterCheckLogic(QueryStructReq queryStructReq, Set resAuthName, + Set sensitiveResReq) { + Set resFilterSet = queryStructUtils.getFilterResNameEnExceptInternalCol(queryStructReq); + Set need2Apply = resFilterSet.stream() + .filter(res -> !resAuthName.contains(res) && sensitiveResReq.contains(res)).collect(Collectors.toSet()); + Set nameCnSet = new HashSet<>(); + + Map modelRespMap = modelService.getModelMap(); + List modelIds = Lists.newArrayList(queryStructReq.getModelIds()); + List dimensionDescList = dimensionService.getDimensions(new MetaFilter(modelIds)); + dimensionDescList.stream().filter(dim -> need2Apply.contains(dim.getBizName())) + .forEach(dim -> nameCnSet.add(modelRespMap.get(dim.getModelId()).getName() + MINUS + dim.getName())); + + if (!CollectionUtils.isEmpty(need2Apply)) { + List admins = modelService.getModelAdmin(modelIds.get(0)); + log.info("in doFilterLogic, need2Apply:{}", need2Apply); + String message = String.format("您没有以下维度%s权限, 请联系管理员%s开通", nameCnSet, admins); + throw new InvalidPermissionException(message); + } + } + + public Object checkStructPermission(ProceedingJoinPoint point, QueryStructReq queryStructReq) throws Throwable { + Object[] args = point.getArgs(); + User user = (User) args[1]; + // fetch data permission meta information + List modelIds = queryStructReq.getModelIds(); + Set res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryStructReq); + log.info("modelId:{}, res4Privilege:{}", modelIds, res4Privilege); + + Set sensitiveResByModel = getHighSensitiveColsByModelId(modelIds); + Set sensitiveResReq = res4Privilege.parallelStream() + .filter(sensitiveResByModel::contains).collect(Collectors.toSet()); + log.info("this query domainId:{}, sensitiveResReq:{}", modelIds, sensitiveResReq); + + // query user privilege info + AuthorizedResourceResp authorizedResource = getAuthorizedResource(user, + modelIds, sensitiveResReq); + // get sensitiveRes that user has privilege + Set resAuthSet = getAuthResNameSet(authorizedResource, + queryStructReq.getModelIds()); + + // if sensitive fields without permission are involved in filter, thrown an exception + doFilterCheckLogic(queryStructReq, resAuthSet, sensitiveResReq); + + // row permission pre-filter + doRowPermission(queryStructReq, authorizedResource); + + // proceed + SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) point.proceed(); + + if (CollectionUtils.isEmpty(sensitiveResReq) || allSensitiveResReqIsOk(sensitiveResReq, resAuthSet)) { + // if sensitiveRes is empty + log.info("sensitiveResReq is empty"); + return getQueryResultWithColumns(queryResultWithColumns, modelIds, authorizedResource); + } + + // if the column has no permission, hit * + Set need2Apply = sensitiveResReq.stream().filter(req -> !resAuthSet.contains(req)) + .collect(Collectors.toSet()); + SemanticQueryResp queryResultAfterDesensitization = + desensitizationData(queryResultWithColumns, need2Apply); + addPromptInfoInfo(modelIds, queryResultAfterDesensitization, authorizedResource, need2Apply); + + return queryResultAfterDesensitization; + + } + + public boolean allSensitiveResReqIsOk(Set sensitiveResReq, Set resAuthSet) { + if (resAuthSet.containsAll(sensitiveResReq)) { + return true; + } + log.info("sensitiveResReq:{}, resAuthSet:{}", sensitiveResReq, resAuthSet); + return false; + } + private void doRowPermission(QuerySqlReq querySQLReq, AuthorizedResourceResp authorizedResource) { log.debug("start doRowPermission logic"); StringJoiner joiner = new StringJoiner(" OR "); @@ -154,33 +282,36 @@ public class S2SQLDataAspect extends AuthCheckBaseAspect { } - private void doFilterCheckLogic(QuerySqlReq querySQLReq, Set resAuthName, - Set sensitiveResReq) { - Set resFilterSet = queryStructUtils.getFilterResNameEnExceptInternalCol(querySQLReq); - Set need2Apply = resFilterSet.stream() - .filter(res -> !resAuthName.contains(res) && sensitiveResReq.contains(res)).collect(Collectors.toSet()); - Set nameCnSet = new HashSet<>(); - - List modelIds = Lists.newArrayList(querySQLReq.getModelIds()); - ModelFilter modelFilter = new ModelFilter(); - modelFilter.setModelIds(modelIds); - List modelInfos = modelService.getModelList(modelFilter); - String modelNameCn = Constants.EMPTY; - if (!CollectionUtils.isEmpty(modelInfos)) { - modelNameCn = modelInfos.get(0).getName(); + private void doRowPermission(QueryStructReq queryStructReq, AuthorizedResourceResp authorizedResource) { + log.debug("start doRowPermission logic"); + StringJoiner joiner = new StringJoiner(" OR "); + List dimensionFilters = new ArrayList<>(); + if (!CollectionUtils.isEmpty(authorizedResource.getFilters())) { + authorizedResource.getFilters().stream() + .forEach(filter -> dimensionFilters.addAll(filter.getExpressions())); } - MetaFilter metaFilter = new MetaFilter(modelIds); - List dimensionDescList = dimensionService.getDimensions(metaFilter); - String finalDomainNameCn = modelNameCn; - dimensionDescList.stream().filter(dim -> need2Apply.contains(dim.getBizName())) - .forEach(dim -> nameCnSet.add(finalDomainNameCn + MINUS + dim.getName())); - if (!CollectionUtils.isEmpty(need2Apply)) { - ModelResp modelResp = modelInfos.get(0); - List admins = modelService.getModelAdmin(modelResp.getId()); - log.info("in doFilterLogic, need2Apply:{}", need2Apply); - String message = String.format("您没有以下维度%s权限, 请联系管理员%s开通", nameCnSet, admins); - throw new InvalidPermissionException(message); + if (CollectionUtils.isEmpty(dimensionFilters)) { + log.debug("dimensionFilters is empty"); + return; } + + dimensionFilters.stream().forEach(filter -> { + if (StringUtils.isNotEmpty(filter) && StringUtils.isNotEmpty(filter.trim())) { + joiner.add(" ( " + filter + " ) "); + } + }); + + if (StringUtils.isNotEmpty(joiner.toString())) { + log.info("before doRowPermission, queryStructReq:{}", queryStructReq); + Filter filter = new Filter("", FilterOperatorEnum.SQL_PART, joiner.toString()); + List filters = Objects.isNull(queryStructReq.getOriginalFilter()) ? new ArrayList<>() + : queryStructReq.getOriginalFilter(); + filters.add(filter); + queryStructReq.setDimensionFilters(filters); + log.info("after doRowPermission, queryStructReq:{}", queryStructReq); + } + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/StructDataAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/StructDataAspect.java deleted file mode 100644 index 36bc32a2a..000000000 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/StructDataAspect.java +++ /dev/null @@ -1,183 +0,0 @@ -package com.tencent.supersonic.headless.server.aspect; - -import com.google.common.base.Strings; -import com.google.common.collect.Lists; -import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; -import com.tencent.supersonic.common.pojo.Filter; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException; -import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; -import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; -import com.tencent.supersonic.headless.api.pojo.response.ModelResp; -import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; -import com.tencent.supersonic.headless.server.utils.QueryStructUtils; -import com.tencent.supersonic.headless.server.pojo.MetaFilter; -import com.tencent.supersonic.headless.server.service.DimensionService; -import com.tencent.supersonic.headless.server.service.ModelService; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.aspectj.lang.ProceedingJoinPoint; -import org.aspectj.lang.annotation.Around; -import org.aspectj.lang.annotation.Aspect; -import org.aspectj.lang.annotation.Pointcut; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.StringJoiner; -import java.util.stream.Collectors; - -import static com.tencent.supersonic.common.pojo.Constants.MINUS; - -@Component -@Aspect -@Slf4j -public class StructDataAspect extends AuthCheckBaseAspect { - @Autowired - private QueryStructUtils queryStructUtils; - @Autowired - private DimensionService dimensionService; - @Autowired - private ModelService modelService; - @Value("${permission.data.enable:true}") - private Boolean permissionDataEnable; - - @Pointcut("@annotation(com.tencent.supersonic.headless.server.annotation.StructDataPermission)") - public void dataPermissionAOP() { - } - - @Around(value = "dataPermissionAOP()") - public Object around(ProceedingJoinPoint point) throws Throwable { - Object[] args = point.getArgs(); - QueryStructReq queryStructReq = (QueryStructReq) args[0]; - User user = (User) args[1]; - - if (!permissionDataEnable) { - log.info("permissionDataEnable is false"); - return point.proceed(); - } - - if (Objects.isNull(user) || Strings.isNullOrEmpty(user.getName())) { - throw new RuntimeException("lease provide user information"); - } - //1. determine whether admin of the model - if (doModelAdmin(user, queryStructReq.getModelIds())) { - return point.proceed(); - } - - // 2. determine whether the subject field is visible - doModelVisible(user, queryStructReq.getModelIds()); - - // 3. fetch data permission meta information - List modelIds = queryStructReq.getModelIds(); - Set res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryStructReq); - log.info("modelId:{}, res4Privilege:{}", modelIds, res4Privilege); - - Set sensitiveResByModel = getHighSensitiveColsByModelId(modelIds); - Set sensitiveResReq = res4Privilege.parallelStream() - .filter(sensitiveResByModel::contains).collect(Collectors.toSet()); - log.info("this query domainId:{}, sensitiveResReq:{}", modelIds, sensitiveResReq); - - // query user privilege info - AuthorizedResourceResp authorizedResource = getAuthorizedResource(user, - modelIds, sensitiveResReq); - // get sensitiveRes that user has privilege - Set resAuthSet = getAuthResNameSet(authorizedResource, - queryStructReq.getModelIds()); - - // 4.if sensitive fields without permission are involved in filter, thrown an exception - doFilterCheckLogic(queryStructReq, resAuthSet, sensitiveResReq); - - // 5.row permission pre-filter - doRowPermission(queryStructReq, authorizedResource); - - // 6.proceed - SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) point.proceed(); - - if (CollectionUtils.isEmpty(sensitiveResReq) || allSensitiveResReqIsOk(sensitiveResReq, resAuthSet)) { - // if sensitiveRes is empty - log.info("sensitiveResReq is empty"); - return getQueryResultWithColumns(queryResultWithColumns, modelIds, authorizedResource); - } - - // 6.if the column has no permission, hit * - Set need2Apply = sensitiveResReq.stream().filter(req -> !resAuthSet.contains(req)) - .collect(Collectors.toSet()); - SemanticQueryResp queryResultAfterDesensitization = - desensitizationData(queryResultWithColumns, need2Apply); - addPromptInfoInfo(modelIds, queryResultAfterDesensitization, authorizedResource, need2Apply); - - return queryResultAfterDesensitization; - - } - - public boolean allSensitiveResReqIsOk(Set sensitiveResReq, Set resAuthSet) { - if (resAuthSet.containsAll(sensitiveResReq)) { - return true; - } - log.info("sensitiveResReq:{}, resAuthSet:{}", sensitiveResReq, resAuthSet); - return false; - } - - private void doRowPermission(QueryStructReq queryStructReq, AuthorizedResourceResp authorizedResource) { - log.debug("start doRowPermission logic"); - StringJoiner joiner = new StringJoiner(" OR "); - List dimensionFilters = new ArrayList<>(); - if (!CollectionUtils.isEmpty(authorizedResource.getFilters())) { - authorizedResource.getFilters().stream() - .forEach(filter -> dimensionFilters.addAll(filter.getExpressions())); - } - - if (CollectionUtils.isEmpty(dimensionFilters)) { - log.debug("dimensionFilters is empty"); - return; - } - - dimensionFilters.stream().forEach(filter -> { - if (StringUtils.isNotEmpty(filter) && StringUtils.isNotEmpty(filter.trim())) { - joiner.add(" ( " + filter + " ) "); - } - }); - - if (StringUtils.isNotEmpty(joiner.toString())) { - log.info("before doRowPermission, queryStructReq:{}", queryStructReq); - Filter filter = new Filter("", FilterOperatorEnum.SQL_PART, joiner.toString()); - List filters = Objects.isNull(queryStructReq.getOriginalFilter()) ? new ArrayList<>() - : queryStructReq.getOriginalFilter(); - filters.add(filter); - queryStructReq.setDimensionFilters(filters); - log.info("after doRowPermission, queryStructReq:{}", queryStructReq); - } - - } - - private void doFilterCheckLogic(QueryStructReq queryStructReq, Set resAuthName, - Set sensitiveResReq) { - Set resFilterSet = queryStructUtils.getFilterResNameEnExceptInternalCol(queryStructReq); - Set need2Apply = resFilterSet.stream() - .filter(res -> !resAuthName.contains(res) && sensitiveResReq.contains(res)).collect(Collectors.toSet()); - Set nameCnSet = new HashSet<>(); - - Map modelRespMap = modelService.getModelMap(); - List modelIds = Lists.newArrayList(queryStructReq.getModelIds()); - List dimensionDescList = dimensionService.getDimensions(new MetaFilter(modelIds)); - dimensionDescList.stream().filter(dim -> need2Apply.contains(dim.getBizName())) - .forEach(dim -> nameCnSet.add(modelRespMap.get(dim.getModelId()).getName() + MINUS + dim.getName())); - - if (!CollectionUtils.isEmpty(need2Apply)) { - List admins = modelService.getModelAdmin(modelIds.get(0)); - log.info("in doFilterLogic, need2Apply:{}", need2Apply); - String message = String.format("您没有以下维度%s权限, 请联系管理员%s开通", nameCnSet, admins); - throw new InvalidPermissionException(message); - } - } - -} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryController.java index 2eea08787..f3eb6f727 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryController.java @@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.api.pojo.request.BatchDownloadReq; import com.tencent.supersonic.headless.api.pojo.request.DownloadStructReq; import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq; -import com.tencent.supersonic.headless.api.pojo.request.ParseSqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryItemReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; @@ -18,23 +17,19 @@ import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.ItemQueryResultResp; import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; -import com.tencent.supersonic.headless.api.pojo.response.SqlParserResp; -import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.server.service.DownloadService; import com.tencent.supersonic.headless.server.service.QueryService; +import java.util.List; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.validation.Valid; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.validation.Valid; -import java.util.List; - @RestController @RequestMapping("/api/semantic/query") @Slf4j @@ -51,7 +46,7 @@ public class QueryController { HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); - return queryService.queryBySql(querySQLReq, user); + return queryService.queryByReq(querySQLReq, user); } @PostMapping("/struct") @@ -60,7 +55,7 @@ public class QueryController { HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); QuerySqlReq querySqlReq = queryStructReq.convert(queryStructReq, true); - return queryService.queryBySql(querySqlReq, user); + return queryService.queryByReq(querySqlReq, user); } @PostMapping("/queryMetricDataById") @@ -85,19 +80,6 @@ public class QueryController { downloadService.batchDownload(batchDownloadReq, user, response); } - @PostMapping("/queryStatement") - public SemanticQueryResp queryStatement(@RequestBody QueryStatement queryStatement) throws Exception { - return queryService.queryByQueryStatement(queryStatement); - } - - @PostMapping("/struct/parse") - public SqlParserResp parseByStruct(@RequestBody ParseSqlReq parseSqlReq) throws Exception { - QueryStatement queryStatement = queryService.explain(parseSqlReq); - SqlParserResp sqlParserResp = new SqlParserResp(); - BeanUtils.copyProperties(queryStatement, sqlParserResp); - return sqlParserResp; - } - /** * queryByMultiStruct */ @@ -106,7 +88,7 @@ public class QueryController { HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); - return queryService.queryByMultiStruct(queryMultiStructReq, user); + return queryService.queryByReq(queryMultiStructReq, user); } /** @@ -132,23 +114,19 @@ public class QueryController { public ExplainResp explain(@RequestBody ExplainSqlReq explainSqlReq, HttpServletRequest request, HttpServletResponse response) throws Exception { - User user = UserHolder.findUser(request, response); String queryReqJson = JsonUtil.toString(explainSqlReq.getQueryReq()); - QueryType queryTypeEnum = explainSqlReq.getQueryTypeEnum(); - if (QueryType.SQL.equals(queryTypeEnum)) { - QuerySqlReq querySQLReq = JsonUtil.toObject(queryReqJson, QuerySqlReq.class); + if (QueryType.SQL.equals(explainSqlReq.getQueryTypeEnum())) { ExplainSqlReq explainSqlReqNew = ExplainSqlReq.builder() - .queryReq(querySQLReq) - .queryTypeEnum(queryTypeEnum).build(); + .queryReq(JsonUtil.toObject(queryReqJson, QuerySqlReq.class)) + .queryTypeEnum(explainSqlReq.getQueryTypeEnum()).build(); return queryService.explain(explainSqlReqNew, user); } - if (QueryType.STRUCT.equals(queryTypeEnum)) { - QueryStructReq queryStructReq = JsonUtil.toObject(queryReqJson, QueryStructReq.class); + if (QueryType.STRUCT.equals(explainSqlReq.getQueryTypeEnum())) { ExplainSqlReq explainSqlReqNew = ExplainSqlReq.builder() - .queryReq(queryStructReq) - .queryTypeEnum(queryTypeEnum).build(); + .queryReq(JsonUtil.toObject(queryReqJson, QueryStructReq.class)) + .queryTypeEnum(explainSqlReq.getQueryTypeEnum()).build(); return queryService.explain(explainSqlReqNew, user); } return null; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java index 3f452657c..664841991 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java @@ -3,46 +3,28 @@ package com.tencent.supersonic.headless.server.service; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq; -import com.tencent.supersonic.headless.api.pojo.request.ParseSqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryItemReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; -import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; -import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.ItemQueryResultResp; import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; -import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.server.annotation.ApiHeaderCheck; import java.util.List; import javax.servlet.http.HttpServletRequest; public interface QueryService { - SemanticQueryResp queryBySql(QuerySqlReq querySqlCmd, User user) throws Exception; - - SemanticQueryResp queryByStruct(QueryStructReq queryStructCmd, User user) throws Exception; - - SemanticQueryResp queryBySemanticQuery(SemanticQueryReq semanticQueryReq, User user) throws Exception; - - SemanticQueryResp queryByStructWithAuth(QueryStructReq queryStructCmd, User user) throws Exception; - - SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructCmd, User user) throws Exception; + SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception; SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user); - SemanticQueryResp queryByQueryStatement(QueryStatement queryStatement); - List getStatInfo(ItemUseReq itemUseCommend); ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception; - QueryStatement explain(ParseSqlReq parseSqlReq) throws Exception; - @ApiHeaderCheck - ItemQueryResultResp queryMetricDataById(QueryItemReq queryApiReq, - HttpServletRequest request) throws Exception; + ItemQueryResultResp queryMetricDataById(QueryItemReq queryApiReq, HttpServletRequest request) throws Exception; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java index 49a600aaf..83d3ef7de 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java @@ -74,7 +74,7 @@ public class DownloadServiceImpl implements DownloadService { File file = FileUtils.createTmpFile(fileName); try { QuerySqlReq querySqlReq = downloadStructReq.convert(downloadStructReq, true); - SemanticQueryResp queryResult = (SemanticQueryResp) queryService.queryBySql(querySqlReq, user); + SemanticQueryResp queryResult = (SemanticQueryResp) queryService.queryByReq(querySqlReq, user); DataDownload dataDownload = buildDataDownload(queryResult, downloadStructReq); EasyExcel.write(file).sheet("Sheet1").head(dataDownload.getHeaders()).doWrite(dataDownload.getData()); } catch (RuntimeException e) { @@ -114,7 +114,7 @@ public class DownloadServiceImpl implements DownloadService { for (MetricSchemaResp metric : metrics) { try { DownloadStructReq downloadStructReq = buildDownloadStructReq(dimensions, metric, batchDownloadReq); - SemanticQueryResp queryResult = queryService.queryByStructWithAuth(downloadStructReq, user); + SemanticQueryResp queryResult = queryService.queryByReq(downloadStructReq, user); DataDownload dataDownload = buildDataDownload(queryResult, downloadStructReq); WriteSheet writeSheet = EasyExcel.writerSheet("Sheet" + sheetCount) .head(dataDownload.getHeaders()).build(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java index 15d6aed9c..e395828bf 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java @@ -20,7 +20,6 @@ import com.tencent.supersonic.headless.api.pojo.SingleItemQueryResult; import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq; import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaFilterReq; -import com.tencent.supersonic.headless.api.pojo.request.ParseSqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryItemReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; @@ -43,8 +42,7 @@ import com.tencent.supersonic.headless.core.parser.QueryParser; import com.tencent.supersonic.headless.core.parser.calcite.s2sql.SemanticModel; import com.tencent.supersonic.headless.core.planner.QueryPlanner; import com.tencent.supersonic.headless.core.pojo.QueryStatement; -import com.tencent.supersonic.headless.server.annotation.S2SQLDataPermission; -import com.tencent.supersonic.headless.server.annotation.StructDataPermission; +import com.tencent.supersonic.headless.server.annotation.S2DataPermission; import com.tencent.supersonic.headless.server.aspect.ApiHeaderCheckAspect; import com.tencent.supersonic.headless.server.manager.SemanticSchemaManager; import com.tencent.supersonic.headless.server.pojo.DimensionFilter; @@ -56,7 +54,6 @@ import com.tencent.supersonic.headless.server.utils.QueryReqConverter; import com.tencent.supersonic.headless.server.utils.QueryUtils; import com.tencent.supersonic.headless.server.utils.StatUtils; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -111,30 +108,40 @@ public class QueryServiceImpl implements QueryService { } @Override - @S2SQLDataPermission + @S2DataPermission @SneakyThrows - public SemanticQueryResp queryBySql(QuerySqlReq querySQLReq, User user) { - return queryBySemanticQuery(querySQLReq, user); - } - - @Override - public SemanticQueryResp queryByStruct(QueryStructReq queryStructCmd, User user) throws Exception { - return queryBySemanticQuery(queryStructCmd, user); - } - - public SemanticQueryResp queryByQueryStatement(QueryStatement queryStatement) { - - SemanticQueryResp queryResultWithColumns = null; - QueryExecutor queryExecutor = queryPlanner.route(queryStatement); - if (queryExecutor != null) { - queryResultWithColumns = queryExecutor.execute(queryStatement); - queryResultWithColumns.setSql(queryStatement.getSql()); - if (!CollectionUtils.isEmpty(queryStatement.getModelIds())) { - queryUtils.fillItemNameInfo(queryResultWithColumns, queryStatement.getModelIds()); + public SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) { + TaskStatusEnum state = TaskStatusEnum.SUCCESS; + log.info("[queryReq:{}]", queryReq); + try { + //1.initStatInfo + statUtils.initStatInfo(queryReq, user); + //2.query from cache + Object query = queryCache.query(queryReq); + if (Objects.nonNull(query)) { + return (SemanticQueryResp) query; } + StatUtils.get().setUseResultCache(false); + //3 query + QueryStatement queryStatement = buildQueryStatement(queryReq, user); + SemanticQueryResp result = query(queryStatement); + //4 reset cache and set stateInfo + Boolean setCacheSuccess = queryCache.put(queryReq, result); + if (setCacheSuccess) { + // if result is not null, update cache data + statUtils.updateResultCacheKey(queryCache.getCacheKey(queryReq)); + } + if (Objects.isNull(result)) { + state = TaskStatusEnum.ERROR; + } + return result; + } catch (Exception e) { + log.error("exception in queryByStruct, e: ", e); + state = TaskStatusEnum.ERROR; + throw e; + } finally { + statUtils.statInfo2DbAsync(state); } - return queryResultWithColumns; - } private QueryStatement buildSqlQueryStatement(QuerySqlReq querySQLReq, User user) throws Exception { @@ -150,41 +157,6 @@ public class QueryServiceImpl implements QueryService { return queryStatement; } - @Override - public SemanticQueryResp queryBySemanticQuery(SemanticQueryReq semanticQueryReq, User user) throws Exception { - TaskStatusEnum state = TaskStatusEnum.SUCCESS; - log.info("[semanticQueryReq:{}]", semanticQueryReq); - try { - //1.initStatInfo - statUtils.initStatInfo(semanticQueryReq, user); - //2.query from cache - Object query = queryCache.query(semanticQueryReq); - if (Objects.nonNull(query)) { - return (SemanticQueryResp) query; - } - StatUtils.get().setUseResultCache(false); - //3 query - QueryStatement queryStatement = buildQueryStatement(semanticQueryReq, user); - SemanticQueryResp result = query(queryStatement); - //4 reset cache and set stateInfo - Boolean setCacheSuccess = queryCache.put(semanticQueryReq, result); - if (setCacheSuccess) { - // if result is not null, update cache data - statUtils.updateResultCacheKey(queryCache.getCacheKey(semanticQueryReq)); - } - if (Objects.isNull(result)) { - state = TaskStatusEnum.ERROR; - } - return result; - } catch (Exception e) { - log.error("exception in queryByStruct, e: ", e); - state = TaskStatusEnum.ERROR; - throw e; - } finally { - statUtils.statInfo2DbAsync(state); - } - } - private QueryStatement buildQueryStatement(SemanticQueryReq semanticQueryReq, User user) throws Exception { if (semanticQueryReq instanceof QuerySqlReq) { return buildSqlQueryStatement((QuerySqlReq) semanticQueryReq, user); @@ -225,66 +197,11 @@ public class QueryServiceImpl implements QueryService { return queryUtils.sqlParserUnion(queryMultiStructReq, sqlParsers); } - @Override - @StructDataPermission - @SneakyThrows - public SemanticQueryResp queryByStructWithAuth(QueryStructReq queryStructReq, User user) { - return queryByStruct(queryStructReq, user); - } - - @Override - public SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) - throws Exception { - TaskStatusEnum state = TaskStatusEnum.SUCCESS; - try { - //1.initStatInfo - statUtils.initStatInfo(queryMultiStructReq.getQueryStructReqs().get(0), user); - //2.query from cache - Object query = queryCache.query(queryMultiStructReq); - if (Objects.nonNull(query)) { - return (SemanticQueryResp) query; - } - StatUtils.get().setUseResultCache(false); - - //3.parse and optimizer - List sqlParsers = new ArrayList<>(); - for (QueryStructReq queryStructReq : queryMultiStructReq.getQueryStructReqs()) { - QueryStatement queryStatement = buildQueryStatement(queryStructReq, user); - queryParser.parse(queryStatement); - queryPlanner.plan(queryStatement); - sqlParsers.add(queryStatement); - } - log.info("multi sqlParser:{}", sqlParsers); - QueryStatement queryStatement = queryUtils.sqlParserUnion(queryMultiStructReq, sqlParsers); - - //4.route - QueryExecutor executor = queryPlanner.route(queryStatement); - - SemanticQueryResp semanticQueryResp = null; - if (executor != null) { - semanticQueryResp = executor.execute(queryStatement); - if (!CollectionUtils.isEmpty(queryStatement.getModelIds())) { - queryUtils.fillItemNameInfo(semanticQueryResp, queryStatement.getModelIds()); - } - } - if (Objects.isNull(semanticQueryResp)) { - state = TaskStatusEnum.ERROR; - } - return semanticQueryResp; - } catch (Exception e) { - log.error("exception in queryByMultiStruct, e: ", e); - state = TaskStatusEnum.ERROR; - throw e; - } finally { - statUtils.statInfo2DbAsync(state); - } - } - @Override @SneakyThrows public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) { QuerySqlReq querySQLReq = buildQuerySqlReq(queryDimValueReq); - return queryBySql(querySQLReq, user); + return queryByReq(querySQLReq, user); } @Override @@ -308,22 +225,6 @@ public class QueryServiceImpl implements QueryService { return getExplainResp(queryStatement); } - @Override - public QueryStatement explain(ParseSqlReq parseSqlReq) throws Exception { - QueryStructReq queryStructCmd = new QueryStructReq(); - Set models = new HashSet<>(); - models.add(Long.valueOf(parseSqlReq.getRootPath())); - queryStructCmd.setModelIds(models); - QueryStatement queryStatement = new QueryStatement(); - queryStatement.setQueryStructReq(queryStructCmd); - queryStatement.setParseSqlReq(parseSqlReq); - queryStatement.setSql(parseSqlReq.getSql()); - queryStatement.setIsS2SQL(true); - SemanticModel semanticModel = semanticSchemaManager.get(parseSqlReq.getRootPath()); - queryStatement.setSemanticModel(semanticModel); - return plan(queryStatement); - } - @Override public ItemQueryResultResp queryMetricDataById(QueryItemReq queryItemReq, HttpServletRequest request) throws Exception { @@ -348,15 +249,14 @@ public class QueryServiceImpl implements QueryService { item.setName(metricResp.getName()); List items = item.getRelateItems(); List dimensionResps = Lists.newArrayList(); - if (!org.springframework.util.CollectionUtils.isEmpty(items)) { + if (!CollectionUtils.isEmpty(items)) { List ids = items.stream().map(Item::getId).collect(Collectors.toList()); DimensionFilter dimensionFilter = new DimensionFilter(); dimensionFilter.setIds(ids); dimensionResps = catalog.getDimensions(dimensionFilter); } QueryStructReq queryStructReq = buildQueryStructReq(dimensionResps, metricResp, dateConf, limit); - SemanticQueryResp semanticQueryResp = - queryByStruct(queryStructReq, User.getAppUser(appId)); + SemanticQueryResp semanticQueryResp = queryByReq(queryStructReq, User.getAppUser(appId)); SingleItemQueryResult apiQuerySingleResult = new SingleItemQueryResult(); apiQuerySingleResult.setItem(item); apiQuerySingleResult.setResult(semanticQueryResp); diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java index 785bb60c5..f126b9a58 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java @@ -32,7 +32,7 @@ class DownloadServiceImplTest { ModelService modelService = Mockito.mock(ModelService.class); QueryService queryService = Mockito.mock(QueryService.class); when(modelService.fetchModelSchema(any())).thenReturn(Lists.newArrayList(mockModelSchemaResp())); - when(queryService.queryByStruct(any(), any())).thenReturn(mockQueryResult()); + when(queryService.queryByReq(any(), any())).thenReturn(mockQueryResult()); DownloadServiceImpl downloadService = new DownloadServiceImpl(modelService, queryService); String fileName = String.format("%s_%s.xlsx", "supersonic", DateUtils.format(new Date(), DateUtils.FORMAT)); File file = FileUtils.createTmpFile(fileName); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/benchmark/CSpider.java b/launchers/standalone/src/test/java/com/tencent/supersonic/benchmark/CSpider.java deleted file mode 100644 index f85f0588b..000000000 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/benchmark/CSpider.java +++ /dev/null @@ -1,10 +0,0 @@ -package com.tencent.supersonic.benchmark; - -import org.junit.Test; - -public class CSpider { - @Test - public void case1(){ - - } -} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/BaseTest.java similarity index 97% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/BaseTest.java index 7727cb03e..d789eb40a 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/BaseTest.java @@ -1,7 +1,8 @@ -package com.tencent.supersonic.integration; +package com.tencent.supersonic.chat.integration; import static org.junit.Assert.assertEquals; +import com.tencent.supersonic.chat.integration.util.DataUtils; import com.tencent.supersonic.StandaloneLauncher; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; @@ -15,7 +16,6 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.ConfigService; import com.tencent.supersonic.chat.server.service.QueryService; -import com.tencent.supersonic.util.DataUtils; import java.time.LocalDate; import java.util.Set; import java.util.stream.Collectors; @@ -30,7 +30,7 @@ import org.springframework.test.context.junit4.SpringRunner; @RunWith(SpringRunner.class) @SpringBootTest(classes = StandaloneLauncher.class) @ActiveProfiles("local") -public class BaseQueryTest { +public class BaseTest { protected final int unit = 7; protected final String startDay = LocalDate.now().plusDays(-unit).toString(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MetricInterpretTest.java similarity index 95% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MetricInterpretTest.java index 1e0b1ae41..4cb625f15 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MetricInterpretTest.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.integration; +package com.tencent.supersonic.chat.integration; import com.tencent.supersonic.StandaloneLauncher; import com.tencent.supersonic.chat.core.query.llm.analytics.LLMAnswerResp; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MetricTest.java similarity index 93% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MetricTest.java index e58e251ea..d96c882da 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MetricTest.java @@ -1,4 +1,7 @@ -package com.tencent.supersonic.integration; +package com.tencent.supersonic.chat.integration; + +import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; +import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; @@ -8,27 +11,23 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery; import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery; import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery; +import com.tencent.supersonic.chat.integration.util.DataUtils; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; -import com.tencent.supersonic.util.DataUtils; -import org.junit.Assert; -import org.junit.Test; - import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; - -import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; -import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM; +import org.junit.Assert; +import org.junit.Test; -public class MetricQueryTest extends BaseQueryTest { +public class MetricTest extends BaseTest { @Test - public void queryTest_metric_filter() throws Exception { + public void testMetricFilter() throws Exception { MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId); @@ -52,7 +51,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_filter_with_agent() { + public void testMetricFilterWithAgent() { //agent only support METRIC_ENTITY, METRIC_FILTER MockConfiguration.mockMetricAgent(agentService); ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId()); @@ -63,7 +62,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_domain() throws Exception { + public void testMetricDomain() throws Exception { MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId); @@ -83,7 +82,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_model_with_agent() { + public void testMetricModelWithAgent() { //agent only support METRIC_ENTITY, METRIC_FILTER MockConfiguration.mockMetricAgent(agentService); ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId()); @@ -93,7 +92,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_groupby() throws Exception { + public void testMetricGroupBy() throws Exception { QueryResult actualResult = submitNewChat("超音数各部门的访问次数", DataUtils.metricAgentId); QueryResult expectedResult = new QueryResult(); @@ -114,7 +113,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_filter_compare() throws Exception { + public void testMetricFilterCompare() throws Exception { MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId); @@ -139,7 +138,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_topn() throws Exception { + public void testMetricTopN() throws Exception { QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId); QueryResult expectedResult = new QueryResult(); @@ -161,7 +160,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_groupby_sum() throws Exception { + public void testMetricGroupBySum() throws Exception { QueryResult actualResult = submitNewChat("超音数各部门的访问次数总和", DataUtils.metricAgentId); QueryResult expectedResult = new QueryResult(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); @@ -181,7 +180,7 @@ public class MetricQueryTest extends BaseQueryTest { } @Test - public void queryTest_metric_filter_time() throws Exception { + public void testMetricFilterTime() throws Exception { MockConfiguration.mockMetricAgent(agentService); DateFormat format = new SimpleDateFormat("yyyy-mm-dd"); DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日"); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MockConfiguration.java similarity index 93% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MockConfiguration.java index d5b2e3202..89a1f9715 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MockConfiguration.java @@ -1,15 +1,15 @@ -package com.tencent.supersonic.integration; +package com.tencent.supersonic.chat.integration; import static org.mockito.Mockito.when; import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.integration.util.DataUtils; import com.tencent.supersonic.chat.core.plugin.PluginManager; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; -import com.tencent.supersonic.util.DataUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Configuration; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MultiTurnsTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MultiTurnsTest.java similarity index 97% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/MultiTurnsTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MultiTurnsTest.java index cb4ab8689..732ed1b87 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MultiTurnsTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/MultiTurnsTest.java @@ -1,7 +1,8 @@ -package com.tencent.supersonic.integration; +package com.tencent.supersonic.chat.integration; import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; +import com.tencent.supersonic.chat.integration.util.DataUtils; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery; @@ -9,13 +10,12 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.util.DataUtils; import java.text.DateFormat; import java.text.SimpleDateFormat; import org.junit.Test; import org.junit.jupiter.api.Order; -public class MultiTurnsTest extends BaseQueryTest { +public class MultiTurnsTest extends BaseTest { @Test @Order(1) diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/TagQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/TagTest.java similarity index 95% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/TagQueryTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/TagTest.java index ea76c91ea..2f28ff446 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/TagQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/TagTest.java @@ -1,5 +1,8 @@ -package com.tencent.supersonic.integration; +package com.tencent.supersonic.chat.integration; +import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; + +import com.tencent.supersonic.chat.integration.util.DataUtils; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; @@ -10,15 +13,11 @@ import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; -import com.tencent.supersonic.util.DataUtils; -import org.junit.Test; - import java.util.ArrayList; import java.util.List; +import org.junit.Test; -import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; - -public class TagQueryTest extends BaseQueryTest { +public class TagTest extends BaseTest { @Test public void queryTest_metric_tag_query() throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/mapper/MapperTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/mapper/MapperTest.java similarity index 89% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/mapper/MapperTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/mapper/MapperTest.java index a9978d598..9fdb5762f 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/mapper/MapperTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/mapper/MapperTest.java @@ -1,5 +1,9 @@ -package com.tencent.supersonic.integration.mapper; +package com.tencent.supersonic.chat.integration.mapper; +import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; + +import com.tencent.supersonic.chat.integration.BaseTest; +import com.tencent.supersonic.chat.integration.util.DataUtils; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; @@ -9,13 +13,9 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; -import com.tencent.supersonic.integration.BaseQueryTest; -import com.tencent.supersonic.util.DataUtils; import org.junit.Test; -import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; - -public class MapperTest extends BaseQueryTest { +public class MapperTest extends BaseTest { @Test public void hanlp() throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/model/MetricServiceImplTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/model/MetricServiceImplTest.java similarity index 96% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/model/MetricServiceImplTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/model/MetricServiceImplTest.java index 2c4479208..aa69ec648 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/model/MetricServiceImplTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/model/MetricServiceImplTest.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.integration.model; +package com.tencent.supersonic.chat.integration.model; import com.google.common.collect.Lists; import com.tencent.supersonic.StandaloneLauncher; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/BasePluginTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/plugin/BasePluginTest.java similarity index 96% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/BasePluginTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/plugin/BasePluginTest.java index 0a2ad2555..e5adb154b 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/BasePluginTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/plugin/BasePluginTest.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.integration.plugin; +package com.tencent.supersonic.chat.integration.plugin; import com.tencent.supersonic.StandaloneLauncher; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/plugin/PluginRecognizeTest.java similarity index 95% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/plugin/PluginRecognizeTest.java index 5c714bbf2..4c03786ea 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/plugin/PluginRecognizeTest.java @@ -1,5 +1,6 @@ -package com.tencent.supersonic.integration.plugin; +package com.tencent.supersonic.chat.integration.plugin; +import com.tencent.supersonic.chat.integration.util.DataUtils; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; @@ -10,8 +11,7 @@ import com.tencent.supersonic.chat.core.plugin.PluginManager; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.QueryService; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.integration.MockConfiguration; -import com.tencent.supersonic.util.DataUtils; +import com.tencent.supersonic.chat.integration.MockConfiguration; import org.junit.Assert; import org.junit.Test; import org.springframework.beans.factory.annotation.Autowired; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/util/DataUtils.java similarity index 98% rename from launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/util/DataUtils.java index cf98afc43..744ffc60f 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/integration/util/DataUtils.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.util; +package com.tencent.supersonic.chat.integration.util; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/integration/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/integration/BaseTest.java new file mode 100644 index 000000000..90bd2fb44 --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/integration/BaseTest.java @@ -0,0 +1,43 @@ +package com.tencent.supersonic.headless.integration; + +import com.tencent.supersonic.StandaloneLauncher; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; +import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import com.tencent.supersonic.headless.server.service.QueryService; +import java.util.HashSet; +import java.util.Set; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.junit4.SpringRunner; + +@RunWith(SpringRunner.class) +@SpringBootTest(classes = StandaloneLauncher.class) +@ActiveProfiles("local") +public class BaseTest { + + @Autowired + private QueryService queryService; + + protected SemanticQueryResp queryBySql(String sql) throws Exception { + return queryBySql(sql, User.getFakeUser()); + } + + protected SemanticQueryResp queryBySql(String sql, User user) throws Exception { + return queryService.queryByReq(buildQuerySqlReq(sql), user); + } + + protected QuerySqlReq buildQuerySqlReq(String sql) { + QuerySqlReq querySqlCmd = new QuerySqlReq(); + querySqlCmd.setSql(sql); + Set modelIds = new HashSet<>(); + modelIds.add(1L); + modelIds.add(2L); + modelIds.add(3L); + querySqlCmd.setModelIds(modelIds); + return querySqlCmd; + } + +} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/integration/QueryBySqlTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/integration/QueryBySqlTest.java new file mode 100644 index 000000000..c9b3eb41c --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/integration/QueryBySqlTest.java @@ -0,0 +1,59 @@ +package com.tencent.supersonic.headless.integration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException; +import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import org.junit.Test; + +public class QueryBySqlTest extends BaseTest { + + @Test + public void testSumQuery() throws Exception { + SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 "); + + assertEquals(1, semanticQueryResp.getColumns().size()); + QueryColumn queryColumn = semanticQueryResp.getColumns().get(0); + assertEquals("访问次数", queryColumn.getName()); + assertEquals(1, semanticQueryResp.getResultList().size()); + } + + @Test + public void testGroupByQuery() throws Exception { + SemanticQueryResp result = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); + assertEquals(2, result.getColumns().size()); + QueryColumn firstColumn = result.getColumns().get(0); + QueryColumn secondColumn = result.getColumns().get(1); + assertEquals("部门", firstColumn.getName()); + assertEquals("访问次数", secondColumn.getName()); + assertEquals(4, result.getResultList().size()); + } + + @Test + public void testCacheQuery() throws Exception { + SemanticQueryResp result1 = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); + SemanticQueryResp result2 = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); + assertEquals(result1, result2); + } + + @Test + public void testBizNameQuery() throws Exception { + SemanticQueryResp result1 = queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'"); + SemanticQueryResp result2 = queryBySql("SELECT SUM(访问次数) FROM 超音数PVUV统计 WHERE 部门 ='HR'"); + assertEquals(1, result1.getColumns().size()); + assertEquals(1, result2.getColumns().size()); + assertEquals(result1.getColumns().get(0), result2.getColumns().get(0)); + assertEquals(result1.getResultList(), result2.getResultList()); + } + + @Test + public void testAuthorization() throws Exception { + User alice = new User(2L, "alice", "alice", "alice@email", 0); + assertThrows(InvalidPermissionException.class, + () -> queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'", alice)); + } + +}