(improvement)[chat] Skip the corrector for complex SQL, and do not add the HAVING field to the SELECT clause (#1754)

This commit is contained in:
lexluo09
2024-10-09 14:38:12 +08:00
committed by GitHub
parent 3ea3c93dc6
commit fc040970b2
9 changed files with 150 additions and 61 deletions

View File

@@ -12,11 +12,28 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
@Slf4j @Slf4j
public class SqlAsHelper { public class SqlAsHelper {
/**
 * Collects the aliases found in the SELECT list of {@code plainSelect} into {@code aliases}.
 *
 * <p>Two passes are made over the select items, in this order: the first adds the name of every
 * alias attached directly to a select item; the second hands each select item to a
 * {@link FunctionAliasVisitor} constructed with the same {@code aliases} set.
 *
 * @param plainSelect the SELECT whose items are inspected
 * @param aliases the set that receives the alias names; it is modified in place
 */
private static void extractAliasesFromSelect(PlainSelect plainSelect, Set<String> aliases) {
    // Pass 1: aliases declared directly on the select items.
    plainSelect.getSelectItems().forEach(item -> {
        Alias itemAlias = item.getAlias();
        if (itemAlias == null) {
            return;
        }
        aliases.add(itemAlias.getName());
    });

    // Pass 2: let the visitor walk every select item; it shares the alias set built above.
    FunctionAliasVisitor functionAliasVisitor = new FunctionAliasVisitor(aliases);
    plainSelect.getSelectItems().forEach(item -> item.accept(functionAliasVisitor));
}
public static List<String> getAsFields(String sql) { public static List<String> getAsFields(String sql) {
List<PlainSelect> plainSelectList = SqlSelectHelper.getPlainSelect(sql); List<PlainSelect> plainSelectList = SqlSelectHelper.getPlainSelect(sql);
if (CollectionUtils.isEmpty(plainSelectList)) { if (CollectionUtils.isEmpty(plainSelectList)) {
@@ -43,17 +60,14 @@ public class SqlAsHelper {
return new ArrayList<>(aliases); return new ArrayList<>(aliases);
} }
private static void extractAliasesFromSelect(PlainSelect plainSelect, Set<String> aliases) { public static Map<String, String> getFieldMapFilterByAsFields(String sql,
// Extract aliases from SELECT items Map<String, String> fieldNameMap) {
for (SelectItem selectItem : plainSelect.getSelectItems()) { // Delete aliases if they exist
Alias alias = selectItem.getAlias(); List<String> asFields = SqlAsHelper.getAsFields(sql);
if (alias != null) { Set<String> asFieldsSet = new HashSet<>(asFields);
aliases.add(alias.getName()); fieldNameMap = fieldNameMap.entrySet().stream()
} .filter(entry -> !asFieldsSet.contains(entry.getKey()))
} .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
FunctionAliasVisitor visitor = new FunctionAliasVisitor(aliases); return fieldNameMap;
for (SelectItem selectItem : plainSelect.getSelectItems()) {
selectItem.accept(visitor);
}
} }
} }

View File

@@ -43,7 +43,9 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.UnaryOperator; import java.util.function.UnaryOperator;
/** Sql Parser replace Helper */ /**
* Sql Parser replace Helper
*/
@Slf4j @Slf4j
public class SqlReplaceHelper { public class SqlReplaceHelper {
public static String replaceAggFields(String sql, public static String replaceAggFields(String sql,
@@ -180,6 +182,8 @@ public class SqlReplaceHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameMap, private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameMap,
boolean exactReplace, PlainSelect plainSelect) { boolean exactReplace, PlainSelect plainSelect) {
// 1. replace where fields // 1. replace where fields

View File

@@ -2,11 +2,14 @@ package com.tencent.supersonic.common.jsqlparser;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.PlainSelect;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import java.util.List; import java.util.List;
/** Sql Parser valid Helper */ /**
* Sql Parser valid Helper
*/
@Slf4j @Slf4j
public class SqlValidHelper { public class SqlValidHelper {
@@ -75,4 +78,9 @@ public class SqlValidHelper {
return false; return false;
} }
} }
/**
 * Reports whether {@code sql} is treated as a complex statement.
 *
 * <p>The statement is complex when {@link SqlSelectHelper#getPlainSelect(String)} yields two or
 * more {@link PlainSelect} entries for it; a result that {@link CollectionUtils#isEmpty} rejects,
 * or one holding a single entry, is not complex.
 *
 * @param sql the SQL text to inspect
 * @return {@code true} if at least two plain selects were extracted from {@code sql}
 */
public static boolean isComplexSQL(String sql) {
    List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(sql);
    if (CollectionUtils.isEmpty(plainSelects)) {
        return false;
    }
    return plainSelects.size() > 1;
}
} }

View File

@@ -61,4 +61,54 @@ class SqlValidHelperTest {
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false); Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
} }
/**
 * Verifies {@link SqlValidHelper#isComplexSQL(String)} against one flat statement and three
 * nested shapes: a sub-query in FROM, two sub-queries combined with LEFT JOIN, and a WITH clause.
 *
 * <p>Each shape is asserted once, and boolean outcomes use {@code assertTrue}/{@code assertFalse}
 * so the check does not depend on the expected/actual argument order of {@code assertEquals}.
 */
@Test
void testIsComplexSQL() {
    // A single flat SELECT must not be reported as complex.
    String flatSql = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
    Assert.assertFalse(SqlValidHelper.isComplexSQL(flatSql));

    // Outer SELECT over a sub-query in the FROM clause.
    String subQuerySql = "SELECT\n" + " COUNT(部门)\n" + "FROM\n" + " (\n" + " SELECT\n"
            + " 部门,\n" + " COUNT(DISTINCT 用户) AS UV\n" + " FROM\n" + " 超音数数据集\n"
            + " WHERE\n" + " 数据日期 >= '2024-09-08'\n"
            + " AND 数据日期 <= '2024-10-08'\n" + " GROUP BY\n" + " 部门\n"
            + " HAVING\n" + " COUNT(DISTINCT 用户) > 2\n" + " ) AS subquery";
    Assert.assertTrue(SqlValidHelper.isComplexSQL(subQuerySql));

    // Two sub-queries combined with a LEFT JOIN.
    String joinSql = " SELECT\n" + " `t6`.`sys_imp_date`,\n" + " `t5`.`department`,\n"
            + " `t6`.`s2_pv_uv_statis_pv` AS `pv`\n" + " FROM\n" + " (\n"
            + " SELECT\n" + " `user_name`,\n" + " `department`\n"
            + " FROM\n" + " `s2_user_department`\n" + " ) AS `t5`\n"
            + " LEFT JOIN (\n" + " SELECT\n"
            + " 1 AS `s2_pv_uv_statis_pv`,\n"
            + " `imp_date` AS `sys_imp_date`,\n" + " `user_name`\n"
            + " FROM\n" + " `s2_pv_uv_statis`\n"
            + " ) AS `t6` ON `t5`.`user_name` = `t6`.`user_name`";
    Assert.assertTrue(SqlValidHelper.isComplexSQL(joinSql));

    // Main SELECT reading from a WITH (common table expression) definition.
    String withSql = "WITH\n" + " UserCounts AS (\n" + " SELECT\n" + " 部门,\n"
            + " COUNT(DISTINCT 用户) AS UV\n" + " FROM\n" + " 超音数数据集\n"
            + " WHERE\n" + " 数据日期 >= '2024-09-08'\n"
            + " AND 数据日期 <= '2024-10-08'\n" + " GROUP BY\n" + " 部门\n" + " )\n"
            + "SELECT\n" + " COUNT(*)\n" + "FROM\n" + " UserCounts\n" + "WHERE\n"
            + " count(UV) > 2";
    Assert.assertTrue(SqlValidHelper.isComplexSQL(withSql));
}
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -8,12 +9,18 @@ import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
/** Verify whether the SQL aggregate function is missing. If it is missing, fill it in. */ /**
* Verify whether the SQL aggregate function is missing. If it is missing, fill it in.
*/
@Slf4j @Slf4j
public class AggCorrector extends BaseSemanticCorrector { public class AggCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
return;
}
addAggregate(chatQueryContext, semanticParseInfo); addAggregate(chatQueryContext, semanticParseInfo);
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -24,6 +25,10 @@ public class GroupByCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
return;
}
Boolean needAddGroupBy = needAddGroupBy(chatQueryContext, semanticParseInfo); Boolean needAddGroupBy = needAddGroupBy(chatQueryContext, semanticParseInfo);
if (!needAddGroupBy) { if (!needAddGroupBy) {
return; return;

View File

@@ -3,14 +3,11 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
@@ -23,18 +20,8 @@ public class HavingCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
// add aggregate to all metric // add aggregate to all metric
addHaving(chatQueryContext, semanticParseInfo); addHaving(chatQueryContext, semanticParseInfo);
// decide whether add having expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo =
environment.getProperty("s2.corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo)
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
addHavingToSelect(semanticParseInfo);
}
} }
private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
@@ -30,6 +31,9 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
return;
}
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL); List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add // If the number of aggregated fields is equal to the number of queried fields, do not add

View File

@@ -274,9 +274,9 @@ public class QueryReqConverter {
String sql = viewQueryParam.getSql(); String sql = viewQueryParam.getSql();
for (MetricTable metricTable : viewQueryParam.getTables()) { for (MetricTable metricTable : viewQueryParam.getTables()) {
Set<String> measures = new HashSet<>(); Set<String> measures = new HashSet<>();
Map<String, String> replaces = new HashMap<>(); Map<String, String> replaces = generateDerivedMetric(semanticSchemaResp, aggOption,
generateDerivedMetric(semanticSchemaResp, aggOption, metricTable.getMetrics(), metricTable.getMetrics(), metricTable.getDimensions(), measures);
metricTable.getDimensions(), measures, replaces);
if (!CollectionUtils.isEmpty(replaces)) { if (!CollectionUtils.isEmpty(replaces)) {
// metricTable sql use measures replace metric // metricTable sql use measures replace metric
sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces); sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces);
@@ -295,49 +295,59 @@ public class QueryReqConverter {
viewQueryParam.setSql(sql); viewQueryParam.setSql(sql);
} }
private void generateDerivedMetric(SemanticSchemaResp semanticSchemaResp, AggOption aggOption, private Map<String, String> generateDerivedMetric(SemanticSchemaResp semanticSchemaResp,
List<String> metrics, List<String> dimensions, Set<String> measures, AggOption aggOption, List<String> metrics, List<String> dimensions,
Map<String, String> replaces) { Set<String> measures) {
Map<String, String> result = new HashMap<>();
List<MetricSchemaResp> metricResps = semanticSchemaResp.getMetrics(); List<MetricSchemaResp> metricResps = semanticSchemaResp.getMetrics();
List<DimSchemaResp> dimensionResps = semanticSchemaResp.getDimensions(); List<DimSchemaResp> dimensionResps = semanticSchemaResp.getDimensions();
// check metrics has derived
if (!metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType // Check if any metric is derived
.isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()))) { boolean hasDerivedMetrics =
return; metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType
.isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()));
if (!hasDerivedMetrics) {
return result;
} }
log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics); log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics);
Set<String> allFields = new HashSet<>(); Set<String> allFields = new HashSet<>();
Map<String, Measure> allMeasures = new HashMap<>(); Map<String, Measure> allMeasures = new HashMap<>();
semanticSchemaResp.getModelResps().forEach(modelResp -> { semanticSchemaResp.getModelResps().forEach(modelResp -> {
allFields.addAll(modelResp.getFieldList()); allFields.addAll(modelResp.getFieldList());
if (Objects.nonNull(modelResp.getModelDetail().getMeasures())) { if (modelResp.getModelDetail().getMeasures() != null) {
modelResp.getModelDetail().getMeasures().stream() modelResp.getModelDetail().getMeasures()
.forEach(mm -> allMeasures.put(mm.getBizName(), mm)); .forEach(measure -> allMeasures.put(measure.getBizName(), measure));
} }
}); });
Set<String> deriveDimension = new HashSet<>();
Set<String> deriveMetric = new HashSet<>(); Set<String> derivedDimensions = new HashSet<>();
Map<String, String> visitedMetric = new HashMap<>(); Set<String> derivedMetrics = new HashSet<>();
if (!CollectionUtils.isEmpty(metricResps)) { Map<String, String> visitedMetrics = new HashMap<>();
for (MetricResp metricResp : metricResps) {
if (metrics.contains(metricResp.getBizName())) { for (MetricResp metricResp : metricResps) {
if (MetricType.isDerived(metricResp.getMetricDefineType(), if (metrics.contains(metricResp.getBizName())) {
metricResp.getMetricDefineByMeasureParams())) { boolean isDerived = MetricType.isDerived(metricResp.getMetricDefineType(),
String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields, metricResp.getMetricDefineByMeasureParams());
allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp), if (isDerived) {
metricResp.getMetricDefineType(), aggOption, visitedMetric, String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields,
deriveMetric, deriveDimension); allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp),
replaces.put(metricResp.getBizName(), expr); metricResp.getMetricDefineType(), aggOption, visitedMetrics,
log.debug("derived metric {}->{}", metricResp.getBizName(), expr); derivedMetrics, derivedDimensions);
} else { result.put(metricResp.getBizName(), expr);
measures.add(metricResp.getBizName()); log.debug("derived metric {}->{}", metricResp.getBizName(), expr);
} } else {
measures.add(metricResp.getBizName());
} }
} }
} }
measures.addAll(deriveMetric);
deriveDimension.stream().filter(d -> !dimensions.contains(d)) measures.addAll(derivedMetrics);
.forEach(d -> dimensions.add(d)); derivedDimensions.stream().filter(dimension -> !dimensions.contains(dimension))
.forEach(dimensions::add);
return result;
} }
private String getDefaultModel(SemanticSchemaResp semanticSchemaResp, List<String> dimensions) { private String getDefaultModel(SemanticSchemaResp semanticSchemaResp, List<String> dimensions) {