diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index 0707981f0..b9e7b9574 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -37,6 +37,7 @@ public class SemanticParseInfo { private List elementMatches = new ArrayList<>(); private Map properties = new HashMap<>(); private SqlInfo sqlInfo = new SqlInfo(); + private SqlEvaluation sqlEvaluation = new SqlEvaluation(); private QueryType queryType = QueryType.ID; private EntityInfo entityInfo; private String textInfo; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlEvaluation.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlEvaluation.java new file mode 100644 index 000000000..26eb564b2 --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlEvaluation.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.headless.api.pojo; + +import lombok.Data; + +@Data +public class SqlEvaluation { + + private Boolean isValidated; + private String validateMsg; +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java index a3aac42cb..e0d8b6a92 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java @@ -37,8 +37,8 @@ public class SqlQueryApiController { @PostMapping("/sql") public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, - HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); String sql = querySqlReq.getSql(); querySqlReq.setSql(StringUtil.replaceBackticks(sql)); @@ -48,8 +48,8 @@ public class SqlQueryApiController { @PostMapping("/sqls") public Object queryBySqls(@RequestBody QuerySqlsReq querySqlsReq, - HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); List semanticQueryReqs = querySqlsReq.getSqls() .stream().map(sql -> { @@ -60,22 +60,22 @@ public class SqlQueryApiController { return querySqlReq; }).collect(Collectors.toList()); List> futures = semanticQueryReqs.stream() - .map(querySqlReq -> CompletableFuture.supplyAsync(() -> { - try { - return queryService.queryByReq(querySqlReq, user); - } catch (Exception e) { - e.printStackTrace(); - return new SemanticQueryResp(); - } - })) - .collect(Collectors.toList()); + .map(querySqlReq -> CompletableFuture.supplyAsync(() -> { + try { + return queryService.queryByReq(querySqlReq, user); + } catch (Exception e) { + e.printStackTrace(); + return new SemanticQueryResp(); + } + })) + .collect(Collectors.toList()); return futures.stream().map(CompletableFuture::join).collect(Collectors.toList()); } @PostMapping("/sqlsWithException") public Object queryBySqlsWithException(@RequestBody QuerySqlsReq querySqlsReq, - HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); List semanticQueryReqs = querySqlsReq.getSqls() .stream().map(sql -> { @@ -97,4 +97,14 @@ public class SqlQueryApiController { return semanticQueryRespList; } + @PostMapping("/validate") + public Object validate(@RequestBody QuerySqlReq querySqlReq, + HttpServletRequest request, + HttpServletResponse response) throws Exception { + User user = UserHolder.findUser(request, response); + String sql = querySqlReq.getSql(); + querySqlReq.setSql(StringUtil.replaceBackticks(sql)); + return chatQueryService.validate(querySqlReq, user); + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java index cb160b384..7a4d3823a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.server.service; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; @@ -32,5 +33,7 @@ public interface ChatQueryService { Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; void correct(QuerySqlReq querySqlReq, User user); + + SqlEvaluation validate(QuerySqlReq querySqlReq, User user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index bdb4cc787..e274d643b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.enums.CostType; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; @@ -182,7 +183,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private QueryResult doExecution(SemanticQueryReq semanticQueryReq, - SemanticParseInfo parseInfo, User user) throws Exception { + SemanticParseInfo parseInfo, User user) throws Exception { SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user); QueryResult queryResult = new QueryResult(); if (queryResp != null) { @@ -330,10 +331,10 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, - Map> filedNameToValueMap, - List fieldExpressionList, - List addConditions, - Set removeFieldNames) { + Map> filedNameToValueMap, + List fieldExpressionList, + List addConditions, + Set removeFieldNames) { if (Objects.isNull(queryData.getDateInfo())) { return; } @@ -357,7 +358,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { for (QueryFilter queryFilter : queryData.getDimensionFilters()) { if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE) && FilterOperatorEnum.LIKE.getValue().toLowerCase().equals( - fieldExpression.getOperator().toLowerCase())) { + fieldExpression.getOperator().toLowerCase())) { Map replaceMap = new HashMap<>(); String preValue = fieldExpression.getFieldValue().toString(); String curValue = queryFilter.getValue().toString(); @@ -377,8 +378,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void addTimeFilters(String date, - T comparisonExpression, - List addConditions) { + T comparisonExpression, + List addConditions) { Column column = new Column(TimeDimensionEnum.DAY.getChName()); StringValue stringValue = new StringValue(date); comparisonExpression.setLeftExpression(column); @@ -387,10 +388,10 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void updateFilters(List fieldExpressionList, - Set metricFilters, - Set contextMetricFilters, - List addConditions, - Set removeFieldNames) { + Set metricFilters, + Set contextMetricFilters, + List addConditions, + Set removeFieldNames) { if (CollectionUtils.isEmpty(metricFilters)) { return; } @@ -426,9 +427,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { // add in condition to sql where condition private void addWhereInFilters(QueryFilter dslQueryFilter, - InExpression inExpression, - Set contextMetricFilters, - List addConditions) { + InExpression inExpression, + Set contextMetricFilters, + List addConditions) { Column column = new Column(dslQueryFilter.getName()); ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); List valueList = JsonUtil.toList( @@ -453,9 +454,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { // add where filter private void addWhereFilters(QueryFilter dslQueryFilter, - T comparisonExpression, - Set contextMetricFilters, - List addConditions) { + T comparisonExpression, + Set contextMetricFilters, + List addConditions) { String columnName = dslQueryFilter.getName(); if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")"; @@ -592,6 +593,17 @@ public class ChatQueryServiceImpl implements ChatQueryService { } public void correct(QuerySqlReq querySqlReq, User user) { + SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user); + querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + } + + @Override + public SqlEvaluation validate(QuerySqlReq querySqlReq, User user) { + SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user); + return semanticParseInfo.getSqlEvaluation(); + } + + private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) { QueryContext queryCtx = new QueryContext(); SemanticSchema semanticSchema = semanticService.getSemanticSchema(); queryCtx.setSemanticSchema(semanticSchema); @@ -615,7 +627,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { } }); log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL()); - querySqlReq.setSql(sqlInfo.getCorrectS2SQL()); + return semanticParseInfo; } }