From 5df0b87da92e71ca130f20a0a8ddc463d0523fce Mon Sep 17 00:00:00 2001 From: ChPi Date: Sun, 17 Aug 2025 18:10:30 +0800 Subject: [PATCH] (fix)(headless) correct SQL when WHERE condition contains only column without function (#2360) --- .../jsqlparser/FiledFilterReplaceVisitor.java | 44 ++++++++++++++++--- .../common/jsqlparser/SqlAddHelper.java | 2 +- .../common/jsqlparser/SqlAddHelperTest.java | 12 ++++- .../chat/corrector/HavingCorrector.java | 9 ++-- 4 files changed, 54 insertions(+), 13 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledFilterReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledFilterReplaceVisitor.java index 8758cf9c1..4ce1c4662 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledFilterReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledFilterReplaceVisitor.java @@ -16,16 +16,21 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import org.apache.commons.collections.CollectionUtils; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.Set; +import java.util.*; @Slf4j public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { private List waitingForAdds = new ArrayList<>(); private Set fieldNames; + private Map fieldNameMap = new HashMap<>(); + + private static Set HAVING_AGG_TYPES = Set.of("SUM", "AVG", "MAX", "MIN", "COUNT"); + + public FiledFilterReplaceVisitor(Map fieldNameMap) { + this.fieldNameMap = fieldNameMap; + this.fieldNames = fieldNameMap.keySet(); + } public FiledFilterReplaceVisitor(Set fieldNames) { this.fieldNames = fieldNames; @@ -82,7 +87,22 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { Expression leftExpression = comparisonOperator.getLeftExpression(); if (!(leftExpression instanceof Function)) { - return result; + if (leftExpression instanceof Column) { + Column leftColumn = (Column) leftExpression; + String agg = fieldNameMap.get(leftColumn.getColumnName()); + if (agg != null && HAVING_AGG_TYPES.contains(agg.toUpperCase())) { + Expression expression = parseCondExpression(comparisonOperator, condExpr); + if (Objects.nonNull(expression)) { + result.add(expression); + return result; + } else { + return null; + } + } + return result; + } else { + return result; + } } Function leftFunction = (Function) leftExpression; @@ -102,14 +122,24 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { return null; } + Expression expression = parseCondExpression(comparisonOperator, condExpr); + if (Objects.nonNull(expression)) { + result.add(expression); + return result; + } else { + return null; + } + } + + private Expression parseCondExpression(ComparisonOperator comparisonOperator, String condExpr) { try { + String comparisonOperatorStr = comparisonOperator.toString(); ComparisonOperator parsedExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr); comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression()); comparisonOperator.setRightExpression(parsedExpression.getRightExpression()); comparisonOperator.setASTNode(parsedExpression.getASTNode()); - result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr)); - return result; + return CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr); } catch (JSQLParserException e) { log.error("JSQLParserException", e); } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java index c90f89695..642a6a860 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java @@ -309,7 +309,7 @@ public class SqlAddHelper { } } - public static String addHaving(String sql, Set fieldNames) { + public static String addHaving(String sql, Map fieldNames) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java index 02a92d9c6..54cdae7ba 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java @@ -338,8 +338,8 @@ class SqlAddHelperTest { List groupByFields = new ArrayList<>(); groupByFields.add("department"); - Set fieldNames = new HashSet<>(); - fieldNames.add("pv"); + Map fieldNames = new HashMap<>(); + fieldNames.put("pv", "sum"); String replaceSql = SqlAddHelper.addHaving(sql, fieldNames); @@ -355,6 +355,14 @@ class SqlAddHelperTest { Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", replaceSql); + + sql = "SELECT 数据日期,访问用户数 FROM 超音数数据集 WHERE 访问次数 > 10 GROUP BY 数据日期"; + + fieldNames.put("访问次数", "sum"); + replaceSql = SqlAddHelper.addHaving(sql, fieldNames); + + Assert.assertEquals("SELECT 数据日期, 访问用户数 FROM 超音数数据集 GROUP BY 数据日期 HAVING 访问次数 > 10", + replaceSql); } @Test diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java index 183b95b47..85cfd64ee 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatQueryContext; @@ -11,7 +12,8 @@ import net.sf.jsqlparser.expression.Expression; import org.springframework.util.CollectionUtils; import java.util.List; -import java.util.Set; +import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; /** Perform SQL corrections on the "Having" section in S2SQL. */ @@ -29,8 +31,9 @@ public class HavingCorrector extends BaseSemanticCorrector { SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); - Set metrics = semanticSchema.getMetrics(dataSet).stream() - .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); + Map metrics = semanticSchema.getMetrics(dataSet).stream() + .collect(Collectors.toMap(SchemaElement::getName, + e -> Optional.ofNullable(e.getDefaultAgg()).orElse(""))); if (CollectionUtils.isEmpty(metrics)) { return;