mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(fix)(headless) correct SQL when WHERE condition contains only column without function (#2360)
This commit is contained in:
@@ -16,16 +16,21 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
|||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||||
private Set<String> fieldNames;
|
private Set<String> fieldNames;
|
||||||
|
private Map<String, String> fieldNameMap = new HashMap<>();
|
||||||
|
|
||||||
|
private static Set<String> HAVING_AGG_TYPES = Set.of("SUM", "AVG", "MAX", "MIN", "COUNT");
|
||||||
|
|
||||||
|
public FiledFilterReplaceVisitor(Map<String, String> fieldNameMap) {
|
||||||
|
this.fieldNameMap = fieldNameMap;
|
||||||
|
this.fieldNames = fieldNameMap.keySet();
|
||||||
|
}
|
||||||
|
|
||||||
public FiledFilterReplaceVisitor(Set<String> fieldNames) {
|
public FiledFilterReplaceVisitor(Set<String> fieldNames) {
|
||||||
this.fieldNames = fieldNames;
|
this.fieldNames = fieldNames;
|
||||||
@@ -82,7 +87,22 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||||
|
|
||||||
if (!(leftExpression instanceof Function)) {
|
if (!(leftExpression instanceof Function)) {
|
||||||
return result;
|
if (leftExpression instanceof Column) {
|
||||||
|
Column leftColumn = (Column) leftExpression;
|
||||||
|
String agg = fieldNameMap.get(leftColumn.getColumnName());
|
||||||
|
if (agg != null && HAVING_AGG_TYPES.contains(agg.toUpperCase())) {
|
||||||
|
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||||
|
if (Objects.nonNull(expression)) {
|
||||||
|
result.add(expression);
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Function leftFunction = (Function) leftExpression;
|
Function leftFunction = (Function) leftExpression;
|
||||||
@@ -102,14 +122,24 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||||
|
if (Objects.nonNull(expression)) {
|
||||||
|
result.add(expression);
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Expression parseCondExpression(ComparisonOperator comparisonOperator, String condExpr) {
|
||||||
try {
|
try {
|
||||||
|
String comparisonOperatorStr = comparisonOperator.toString();
|
||||||
ComparisonOperator parsedExpression =
|
ComparisonOperator parsedExpression =
|
||||||
(ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
(ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||||
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
|
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
|
||||||
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
|
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
|
||||||
comparisonOperator.setASTNode(parsedExpression.getASTNode());
|
comparisonOperator.setASTNode(parsedExpression.getASTNode());
|
||||||
result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr));
|
return CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr);
|
||||||
return result;
|
|
||||||
} catch (JSQLParserException e) {
|
} catch (JSQLParserException e) {
|
||||||
log.error("JSQLParserException", e);
|
log.error("JSQLParserException", e);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ public class SqlAddHelper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static String addHaving(String sql, Set<String> fieldNames) {
|
public static String addHaving(String sql, Map<String, String> fieldNames) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
|
|||||||
@@ -338,8 +338,8 @@ class SqlAddHelperTest {
|
|||||||
List<String> groupByFields = new ArrayList<>();
|
List<String> groupByFields = new ArrayList<>();
|
||||||
groupByFields.add("department");
|
groupByFields.add("department");
|
||||||
|
|
||||||
Set<String> fieldNames = new HashSet<>();
|
Map<String, String> fieldNames = new HashMap<>();
|
||||||
fieldNames.add("pv");
|
fieldNames.put("pv", "sum");
|
||||||
|
|
||||||
String replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
|
String replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
|
||||||
|
|
||||||
@@ -355,6 +355,14 @@ class SqlAddHelperTest {
|
|||||||
Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||||
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||||
replaceSql);
|
replaceSql);
|
||||||
|
|
||||||
|
sql = "SELECT 数据日期,访问用户数 FROM 超音数数据集 WHERE 访问次数 > 10 GROUP BY 数据日期";
|
||||||
|
|
||||||
|
fieldNames.put("访问次数", "sum");
|
||||||
|
replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
|
||||||
|
|
||||||
|
Assert.assertEquals("SELECT 数据日期, 访问用户数 FROM 超音数数据集 GROUP BY 数据日期 HAVING 访问次数 > 10",
|
||||||
|
replaceSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
|||||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
@@ -11,7 +12,8 @@ import net.sf.jsqlparser.expression.Expression;
|
|||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/** Perform SQL corrections on the "Having" section in S2SQL. */
|
/** Perform SQL corrections on the "Having" section in S2SQL. */
|
||||||
@@ -29,8 +31,9 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
Map<String, String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
.collect(Collectors.toMap(SchemaElement::getName,
|
||||||
|
e -> Optional.ofNullable(e.getDefaultAgg()).orElse("")));
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(metrics)) {
|
if (CollectionUtils.isEmpty(metrics)) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
Reference in New Issue
Block a user