diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java index c82b540ad..5e2c46f4d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java @@ -6,6 +6,7 @@ import javax.servlet.http.HttpServletResponse; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.util.StringUtil; +import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; @@ -38,7 +39,7 @@ public class SqlQueryApiController { @PostMapping("/sql") public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); String sql = querySqlReq.getSql(); querySqlReq.setSql(StringUtil.replaceBackticks(sql)); @@ -48,7 +49,7 @@ public class SqlQueryApiController { @PostMapping("/sqls") public Object queryBySqls(@RequestBody QuerySqlsReq querySqlsReq, HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); List semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> { QuerySqlReq querySqlReq = new QuerySqlReq(); @@ -72,7 +73,7 @@ public class SqlQueryApiController { @PostMapping("/sqlsWithException") public Object queryBySqlsWithException(@RequestBody QuerySqlsReq querySqlsReq, - HttpServletRequest request, HttpServletResponse response) throws Exception { + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); List semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> { QuerySqlReq querySqlReq = new QuerySqlReq(); @@ -96,10 +97,34 @@ public class SqlQueryApiController { @PostMapping("/validate") public Object validate(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); String sql = querySqlReq.getSql(); querySqlReq.setSql(StringUtil.replaceBackticks(sql)); return chatLayerService.validate(querySqlReq, user); } + + @PostMapping("/validateAndQuery") + public Object validateAndQuery(@RequestBody QuerySqlsReq querySqlsReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { + User user = UserHolder.findUser(request, response); + List convert = convert(querySqlsReq); + for (QuerySqlReq querySqlReq : convert) { + SqlEvaluation validate = chatLayerService.validate(querySqlReq, user); + if (!validate.getIsValidated()) { + throw new Exception(validate.getValidateMsg()); + } + } + return queryBySqls(querySqlsReq, request, response); + } + + private List convert(QuerySqlsReq querySqlsReq) { + return querySqlsReq.getSqls().stream().map(sql -> { + QuerySqlReq querySqlReq = new QuerySqlReq(); + BeanUtils.copyProperties(querySqlsReq, querySqlReq); + querySqlReq.setSql(StringUtil.replaceBackticks(sql)); + return querySqlReq; + }).collect(Collectors.toList()); + } + }