(improvement)(headless) add SqlEvaluation interface (#1150)

This commit is contained in:
jipeli
2024-06-14 22:41:04 +08:00
committed by GitHub
parent 0509242242
commit f8b818cb82
5 changed files with 70 additions and 34 deletions

View File

@@ -37,6 +37,7 @@ public class SemanticParseInfo {
private List<SchemaElementMatch> elementMatches = new ArrayList<>(); private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>(); private Map<String, Object> properties = new HashMap<>();
private SqlInfo sqlInfo = new SqlInfo(); private SqlInfo sqlInfo = new SqlInfo();
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
private QueryType queryType = QueryType.ID; private QueryType queryType = QueryType.ID;
private EntityInfo entityInfo; private EntityInfo entityInfo;
private String textInfo; private String textInfo;

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.headless.api.pojo;
import lombok.Data;
@Data
public class SqlEvaluation {
private Boolean isValidated;
private String validateMsg;
}

View File

@@ -37,8 +37,8 @@ public class SqlQueryApiController {
@PostMapping("/sql") @PostMapping("/sql")
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql(); String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql)); querySqlReq.setSql(StringUtil.replaceBackticks(sql));
@@ -48,8 +48,8 @@ public class SqlQueryApiController {
@PostMapping("/sqls") @PostMapping("/sqls")
public Object queryBySqls(@RequestBody QuerySqlsReq querySqlsReq, public Object queryBySqls(@RequestBody QuerySqlsReq querySqlsReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
List<SemanticQueryReq> semanticQueryReqs = querySqlsReq.getSqls() List<SemanticQueryReq> semanticQueryReqs = querySqlsReq.getSqls()
.stream().map(sql -> { .stream().map(sql -> {
@@ -60,22 +60,22 @@ public class SqlQueryApiController {
return querySqlReq; return querySqlReq;
}).collect(Collectors.toList()); }).collect(Collectors.toList());
List<CompletableFuture<SemanticQueryResp>> futures = semanticQueryReqs.stream() List<CompletableFuture<SemanticQueryResp>> futures = semanticQueryReqs.stream()
.map(querySqlReq -> CompletableFuture.supplyAsync(() -> { .map(querySqlReq -> CompletableFuture.supplyAsync(() -> {
try { try {
return queryService.queryByReq(querySqlReq, user); return queryService.queryByReq(querySqlReq, user);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
return new SemanticQueryResp(); return new SemanticQueryResp();
} }
})) }))
.collect(Collectors.toList()); .collect(Collectors.toList());
return futures.stream().map(CompletableFuture::join).collect(Collectors.toList()); return futures.stream().map(CompletableFuture::join).collect(Collectors.toList());
} }
@PostMapping("/sqlsWithException") @PostMapping("/sqlsWithException")
public Object queryBySqlsWithException(@RequestBody QuerySqlsReq querySqlsReq, public Object queryBySqlsWithException(@RequestBody QuerySqlsReq querySqlsReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
List<SemanticQueryReq> semanticQueryReqs = querySqlsReq.getSqls() List<SemanticQueryReq> semanticQueryReqs = querySqlsReq.getSqls()
.stream().map(sql -> { .stream().map(sql -> {
@@ -97,4 +97,14 @@ public class SqlQueryApiController {
return semanticQueryRespList; return semanticQueryRespList;
} }
@PostMapping("/validate")
public Object validate(@RequestBody QuerySqlReq querySqlReq,
HttpServletRequest request,
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
return chatQueryService.validate(querySqlReq, user);
}
} }

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
@@ -32,5 +33,7 @@ public interface ChatQueryService {
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
void correct(QuerySqlReq querySqlReq, User user); void correct(QuerySqlReq querySqlReq, User user);
SqlEvaluation validate(QuerySqlReq querySqlReq, User user);
} }

View File

@@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.CostType; import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
@@ -182,7 +183,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
SemanticParseInfo parseInfo, User user) throws Exception { SemanticParseInfo parseInfo, User user) throws Exception {
SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user); SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user);
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
if (queryResp != null) { if (queryResp != null) {
@@ -330,10 +331,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap, Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList, List<FieldExpression> fieldExpressionList,
List<Expression> addConditions, List<Expression> addConditions,
Set<String> removeFieldNames) { Set<String> removeFieldNames) {
if (Objects.isNull(queryData.getDateInfo())) { if (Objects.isNull(queryData.getDateInfo())) {
return; return;
} }
@@ -357,7 +358,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
for (QueryFilter queryFilter : queryData.getDimensionFilters()) { for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE) if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
&& FilterOperatorEnum.LIKE.getValue().toLowerCase().equals( && FilterOperatorEnum.LIKE.getValue().toLowerCase().equals(
fieldExpression.getOperator().toLowerCase())) { fieldExpression.getOperator().toLowerCase())) {
Map<String, String> replaceMap = new HashMap<>(); Map<String, String> replaceMap = new HashMap<>();
String preValue = fieldExpression.getFieldValue().toString(); String preValue = fieldExpression.getFieldValue().toString();
String curValue = queryFilter.getValue().toString(); String curValue = queryFilter.getValue().toString();
@@ -377,8 +378,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private <T extends ComparisonOperator> void addTimeFilters(String date, private <T extends ComparisonOperator> void addTimeFilters(String date,
T comparisonExpression, T comparisonExpression,
List<Expression> addConditions) { List<Expression> addConditions) {
Column column = new Column(TimeDimensionEnum.DAY.getChName()); Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(date); StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column); comparisonExpression.setLeftExpression(column);
@@ -387,10 +388,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private void updateFilters(List<FieldExpression> fieldExpressionList, private void updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters, Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions, List<Expression> addConditions,
Set<String> removeFieldNames) { Set<String> removeFieldNames) {
if (CollectionUtils.isEmpty(metricFilters)) { if (CollectionUtils.isEmpty(metricFilters)) {
return; return;
} }
@@ -426,9 +427,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// add in condition to sql where condition // add in condition to sql where condition
private void addWhereInFilters(QueryFilter dslQueryFilter, private void addWhereInFilters(QueryFilter dslQueryFilter,
InExpression inExpression, InExpression inExpression,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
Column column = new Column(dslQueryFilter.getName()); Column column = new Column(dslQueryFilter.getName());
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
List<String> valueList = JsonUtil.toList( List<String> valueList = JsonUtil.toList(
@@ -453,9 +454,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// add where filter // add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter, private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression, T comparisonExpression,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
String columnName = dslQueryFilter.getName(); String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")"; columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
@@ -592,6 +593,17 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
public void correct(QuerySqlReq querySqlReq, User user) { public void correct(QuerySqlReq querySqlReq, User user) {
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectS2SQL());
}
@Override
public SqlEvaluation validate(QuerySqlReq querySqlReq, User user) {
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
return semanticParseInfo.getSqlEvaluation();
}
private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
QueryContext queryCtx = new QueryContext(); QueryContext queryCtx = new QueryContext();
SemanticSchema semanticSchema = semanticService.getSemanticSchema(); SemanticSchema semanticSchema = semanticService.getSemanticSchema();
queryCtx.setSemanticSchema(semanticSchema); queryCtx.setSemanticSchema(semanticSchema);
@@ -615,7 +627,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
}); });
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL()); log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL());
querySqlReq.setSql(sqlInfo.getCorrectS2SQL()); return semanticParseInfo;
} }
} }