From fc040970b2e5282c76d25efb50a0c19c36bbc6f7 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:38:12 +0800 Subject: [PATCH] (improvement)[chat] Skip the corrector for complex SQL, and do not add the HAVING field to the SELECT clause (#1754) --- .../common/jsqlparser/SqlAsHelper.java | 38 +++++++--- .../common/jsqlparser/SqlReplaceHelper.java | 6 +- .../common/jsqlparser/SqlValidHelper.java | 10 ++- .../common/jsqlparser/SqlValidHelperTest.java | 50 ++++++++++++ .../headless/chat/corrector/AggCorrector.java | 9 ++- .../chat/corrector/GroupByCorrector.java | 5 ++ .../chat/corrector/HavingCorrector.java | 13 ---- .../chat/corrector/SelectCorrector.java | 4 + .../server/utils/QueryReqConverter.java | 76 +++++++++++-------- 9 files changed, 150 insertions(+), 61 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java index a93639120..1b940d27c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java @@ -12,11 +12,28 @@ import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; @Slf4j public class SqlAsHelper { + private static void extractAliasesFromSelect(PlainSelect plainSelect, Set aliases) { + // Extract aliases from SELECT items + for (SelectItem selectItem : plainSelect.getSelectItems()) { + Alias alias = selectItem.getAlias(); + if (alias != null) { + aliases.add(alias.getName()); + } + } + FunctionAliasVisitor visitor = new FunctionAliasVisitor(aliases); + for (SelectItem selectItem : plainSelect.getSelectItems()) { + selectItem.accept(visitor); + } + } + + public static List getAsFields(String sql) { List plainSelectList = SqlSelectHelper.getPlainSelect(sql); if (CollectionUtils.isEmpty(plainSelectList)) { @@ -43,17 +60,14 @@ public class SqlAsHelper { return new ArrayList<>(aliases); } - private static void extractAliasesFromSelect(PlainSelect plainSelect, Set aliases) { - // Extract aliases from SELECT items - for (SelectItem selectItem : plainSelect.getSelectItems()) { - Alias alias = selectItem.getAlias(); - if (alias != null) { - aliases.add(alias.getName()); - } - } - FunctionAliasVisitor visitor = new FunctionAliasVisitor(aliases); - for (SelectItem selectItem : plainSelect.getSelectItems()) { - selectItem.accept(visitor); - } + public static Map getFieldMapFilterByAsFields(String sql, + Map fieldNameMap) { + // Delete aliases if they exist + List asFields = SqlAsHelper.getAsFields(sql); + Set asFieldsSet = new HashSet<>(asFields); + fieldNameMap = fieldNameMap.entrySet().stream() + .filter(entry -> !asFieldsSet.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + return fieldNameMap; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 32086230f..3f69ad0e9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -43,7 +43,9 @@ import java.util.Objects; import java.util.Set; import java.util.function.UnaryOperator; -/** Sql Parser replace Helper */ +/** + * Sql Parser replace Helper + */ @Slf4j public class SqlReplaceHelper { public static String replaceAggFields(String sql, @@ -180,6 +182,8 @@ public class SqlReplaceHelper { return selectStatement.toString(); } + + private static void replaceFieldsInPlainOneSelect(Map fieldNameMap, boolean exactReplace, PlainSelect plainSelect) { // 1. replace where fields diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java index 638fec08f..c9b45d79f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java @@ -2,11 +2,14 @@ package com.tencent.supersonic.common.jsqlparser; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.select.PlainSelect; import org.apache.commons.collections.CollectionUtils; import java.util.List; -/** Sql Parser valid Helper */ +/** + * Sql Parser valid Helper + */ @Slf4j public class SqlValidHelper { @@ -75,4 +78,9 @@ public class SqlValidHelper { return false; } } + + public static boolean isComplexSQL(String sql) { + List plainSelect = SqlSelectHelper.getPlainSelect(sql); + return !CollectionUtils.isEmpty(plainSelect) && plainSelect.size() >= 2; + } } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java index e455078bb..b058d5a5f 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java @@ -61,4 +61,54 @@ class SqlValidHelperTest { Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false); } + + @Test + void testIsComplexSQL() { + String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2"; + Assert.assertEquals(SqlValidHelper.isComplexSQL(sql1), false); + + sql1 = "SELECT\n" + " COUNT(部门)\n" + "FROM\n" + " (\n" + " SELECT\n" + " 部门,\n" + + " COUNT(DISTINCT 用户) AS UV\n" + " FROM\n" + " 超音数数据集\n" + + " WHERE\n" + " 数据日期 >= '2024-09-08'\n" + + " AND 数据日期 <= '2024-10-08'\n" + " GROUP BY\n" + " 部门\n" + + " HAVING\n" + " COUNT(DISTINCT 用户) > 2\n" + " ) AS subquery"; + Assert.assertEquals(SqlValidHelper.isComplexSQL(sql1), true); + + sql1 = "SELECT\n" + " COUNT(部门)\n" + "FROM\n" + " (\n" + " SELECT\n" + " 部门,\n" + + " COUNT(DISTINCT 用户) AS UV\n" + " FROM\n" + " 超音数数据集\n" + + " WHERE\n" + " 数据日期 >= '2024-09-08'\n" + + " AND 数据日期 <= '2024-10-08'\n" + " GROUP BY\n" + " 部门\n" + + " HAVING\n" + " COUNT(DISTINCT 用户) > 2\n" + " ) AS subquery"; + Assert.assertEquals(SqlValidHelper.isComplexSQL(sql1), true); + + sql1 = " SELECT\n" + " `t6`.`sys_imp_date`,\n" + " `t5`.`department`,\n" + + " `t6`.`s2_pv_uv_statis_pv` AS `pv`\n" + " FROM\n" + " (\n" + + " SELECT\n" + " `user_name`,\n" + " `department`\n" + + " FROM\n" + " `s2_user_department`\n" + " ) AS `t5`\n" + + " LEFT JOIN (\n" + " SELECT\n" + + " 1 AS `s2_pv_uv_statis_pv`,\n" + + " `imp_date` AS `sys_imp_date`,\n" + " `user_name`\n" + + " FROM\n" + " `s2_pv_uv_statis`\n" + + " ) AS `t6` ON `t5`.`user_name` = `t6`.`user_name`"; + Assert.assertEquals(SqlValidHelper.isComplexSQL(sql1), true); + + sql1 = " SELECT\n" + " `t6`.`sys_imp_date`,\n" + " `t5`.`department`,\n" + + " `t6`.`s2_pv_uv_statis_pv` AS `pv`\n" + " FROM\n" + " (\n" + + " SELECT\n" + " `user_name`,\n" + " `department`\n" + + " FROM\n" + " `s2_user_department`\n" + " ) AS `t5`\n" + + " LEFT JOIN (\n" + " SELECT\n" + + " 1 AS `s2_pv_uv_statis_pv`,\n" + + " `imp_date` AS `sys_imp_date`,\n" + " `user_name`\n" + + " FROM\n" + " `s2_pv_uv_statis`\n" + + " ) AS `t6` ON `t5`.`user_name` = `t6`.`user_name`"; + Assert.assertEquals(SqlValidHelper.isComplexSQL(sql1), true); + + sql1 = "WITH\n" + " UserCounts AS (\n" + " SELECT\n" + " 部门,\n" + + " COUNT(DISTINCT 用户) AS UV\n" + " FROM\n" + " 超音数数据集\n" + + " WHERE\n" + " 数据日期 >= '2024-09-08'\n" + + " AND 数据日期 <= '2024-10-08'\n" + " GROUP BY\n" + " 部门\n" + " )\n" + + "SELECT\n" + " COUNT(*)\n" + "FROM\n" + " UserCounts\n" + "WHERE\n" + + " count(UV) > 2"; + Assert.assertEquals(SqlValidHelper.isComplexSQL(sql1), true); + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java index 33b9e98a5..ed8f17ccd 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; @@ -8,12 +9,18 @@ import org.springframework.util.CollectionUtils; import java.util.List; -/** Verify whether the SQL aggregate function is missing. If it is missing, fill it in. */ +/** + * Verify whether the SQL aggregate function is missing. If it is missing, fill it in. + */ @Slf4j public class AggCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); + if (SqlValidHelper.isComplexSQL(correctS2SQL)) { + return; + } addAggregate(chatQueryContext, semanticParseInfo); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java index 1dd74fc27..eb7a4da70 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; @@ -24,6 +25,10 @@ public class GroupByCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); + if (SqlValidHelper.isComplexSQL(correctS2SQL)) { + return; + } Boolean needAddGroupBy = needAddGroupBy(chatQueryContext, semanticParseInfo); if (!needAddGroupBy) { return; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java index cac0ef368..183b95b47 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java @@ -3,14 +3,11 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; -import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; -import org.apache.commons.lang3.StringUtils; -import org.springframework.core.env.Environment; import org.springframework.util.CollectionUtils; import java.util.List; @@ -23,18 +20,8 @@ public class HavingCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - // add aggregate to all metric addHaving(chatQueryContext, semanticParseInfo); - - // decide whether add having expression field to select - Environment environment = ContextUtils.getBean(Environment.class); - String correctorAdditionalInfo = - environment.getProperty("s2.corrector.additional.information"); - if (StringUtils.isNotBlank(correctorAdditionalInfo) - && Boolean.parseBoolean(correctorAdditionalInfo)) { - addHavingToSelect(semanticParseInfo); - } } private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index e70d5d13c..9ffa308d9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -4,6 +4,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; @@ -30,6 +31,9 @@ public class SelectCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); + if (SqlValidHelper.isComplexSQL(correctS2SQL)) { + return; + } List aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL); List selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); // If the number of aggregated fields is equal to the number of queried fields, do not add diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java index faea0945a..acd3e6ac2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java @@ -274,9 +274,9 @@ public class QueryReqConverter { String sql = viewQueryParam.getSql(); for (MetricTable metricTable : viewQueryParam.getTables()) { Set measures = new HashSet<>(); - Map replaces = new HashMap<>(); - generateDerivedMetric(semanticSchemaResp, aggOption, metricTable.getMetrics(), - metricTable.getDimensions(), measures, replaces); + Map replaces = generateDerivedMetric(semanticSchemaResp, aggOption, + metricTable.getMetrics(), metricTable.getDimensions(), measures); + if (!CollectionUtils.isEmpty(replaces)) { // metricTable sql use measures replace metric sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces); @@ -295,49 +295,59 @@ public class QueryReqConverter { viewQueryParam.setSql(sql); } - private void generateDerivedMetric(SemanticSchemaResp semanticSchemaResp, AggOption aggOption, - List metrics, List dimensions, Set measures, - Map replaces) { + private Map generateDerivedMetric(SemanticSchemaResp semanticSchemaResp, + AggOption aggOption, List metrics, List dimensions, + Set measures) { + Map result = new HashMap<>(); List metricResps = semanticSchemaResp.getMetrics(); List dimensionResps = semanticSchemaResp.getDimensions(); - // check metrics has derived - if (!metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType - .isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()))) { - return; + + // Check if any metric is derived + boolean hasDerivedMetrics = + metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType + .isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams())); + if (!hasDerivedMetrics) { + return result; } + log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics); + Set allFields = new HashSet<>(); Map allMeasures = new HashMap<>(); semanticSchemaResp.getModelResps().forEach(modelResp -> { allFields.addAll(modelResp.getFieldList()); - if (Objects.nonNull(modelResp.getModelDetail().getMeasures())) { - modelResp.getModelDetail().getMeasures().stream() - .forEach(mm -> allMeasures.put(mm.getBizName(), mm)); + if (modelResp.getModelDetail().getMeasures() != null) { + modelResp.getModelDetail().getMeasures() + .forEach(measure -> allMeasures.put(measure.getBizName(), measure)); } }); - Set deriveDimension = new HashSet<>(); - Set deriveMetric = new HashSet<>(); - Map visitedMetric = new HashMap<>(); - if (!CollectionUtils.isEmpty(metricResps)) { - for (MetricResp metricResp : metricResps) { - if (metrics.contains(metricResp.getBizName())) { - if (MetricType.isDerived(metricResp.getMetricDefineType(), - metricResp.getMetricDefineByMeasureParams())) { - String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields, - allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp), - metricResp.getMetricDefineType(), aggOption, visitedMetric, - deriveMetric, deriveDimension); - replaces.put(metricResp.getBizName(), expr); - log.debug("derived metric {}->{}", metricResp.getBizName(), expr); - } else { - measures.add(metricResp.getBizName()); - } + + Set derivedDimensions = new HashSet<>(); + Set derivedMetrics = new HashSet<>(); + Map visitedMetrics = new HashMap<>(); + + for (MetricResp metricResp : metricResps) { + if (metrics.contains(metricResp.getBizName())) { + boolean isDerived = MetricType.isDerived(metricResp.getMetricDefineType(), + metricResp.getMetricDefineByMeasureParams()); + if (isDerived) { + String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields, + allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp), + metricResp.getMetricDefineType(), aggOption, visitedMetrics, + derivedMetrics, derivedDimensions); + result.put(metricResp.getBizName(), expr); + log.debug("derived metric {}->{}", metricResp.getBizName(), expr); + } else { + measures.add(metricResp.getBizName()); } } } - measures.addAll(deriveMetric); - deriveDimension.stream().filter(d -> !dimensions.contains(d)) - .forEach(d -> dimensions.add(d)); + + measures.addAll(derivedMetrics); + derivedDimensions.stream().filter(dimension -> !dimensions.contains(dimension)) + .forEach(dimensions::add); + + return result; } private String getDefaultModel(SemanticSchemaResp semanticSchemaResp, List dimensions) {