2 Commits

Author SHA1 Message Date
ChPi
5df0b87da9 (fix)(headless) correct SQL when WHERE condition contains only column without function (#2360)
Some checks failed
supersonic CentOS CI / build (21) (push) Has been cancelled
supersonic mac CI / build (21) (push) Has been cancelled
supersonic ubuntu CI / build (21) (push) Has been cancelled
supersonic windows CI / build (21) (push) Has been cancelled
2025-08-17 18:10:30 +08:00
ChPi
ab24b1777a (fix)(common) prevent NullPointerException for jsonFormat (#2365) 2025-08-17 16:21:36 +08:00
5 changed files with 55 additions and 14 deletions

View File

@@ -16,16 +16,21 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.*;
@Slf4j
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
private List<Expression> waitingForAdds = new ArrayList<>();
private Set<String> fieldNames;
private Map<String, String> fieldNameMap = new HashMap<>();
private static Set<String> HAVING_AGG_TYPES = Set.of("SUM", "AVG", "MAX", "MIN", "COUNT");
public FiledFilterReplaceVisitor(Map<String, String> fieldNameMap) {
this.fieldNameMap = fieldNameMap;
this.fieldNames = fieldNameMap.keySet();
}
public FiledFilterReplaceVisitor(Set<String> fieldNames) {
this.fieldNames = fieldNames;
@@ -82,7 +87,22 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
Expression leftExpression = comparisonOperator.getLeftExpression();
if (!(leftExpression instanceof Function)) {
return result;
if (leftExpression instanceof Column) {
Column leftColumn = (Column) leftExpression;
String agg = fieldNameMap.get(leftColumn.getColumnName());
if (agg != null && HAVING_AGG_TYPES.contains(agg.toUpperCase())) {
Expression expression = parseCondExpression(comparisonOperator, condExpr);
if (Objects.nonNull(expression)) {
result.add(expression);
return result;
} else {
return null;
}
}
return result;
} else {
return result;
}
}
Function leftFunction = (Function) leftExpression;
@@ -102,14 +122,24 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
return null;
}
Expression expression = parseCondExpression(comparisonOperator, condExpr);
if (Objects.nonNull(expression)) {
result.add(expression);
return result;
} else {
return null;
}
}
private Expression parseCondExpression(ComparisonOperator comparisonOperator, String condExpr) {
try {
String comparisonOperatorStr = comparisonOperator.toString();
ComparisonOperator parsedExpression =
(ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
comparisonOperator.setASTNode(parsedExpression.getASTNode());
result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr));
return result;
return CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr);
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}

View File

@@ -309,7 +309,7 @@ public class SqlAddHelper {
}
}
public static String addHaving(String sql, Set<String> fieldNames) {
public static String addHaving(String sql, Map<String, String> fieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {

View File

@@ -30,7 +30,7 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses());
if (modelConfig.getJsonFormat()) {
if (modelConfig.getJsonFormat() != null && modelConfig.getJsonFormat()) {
openAiChatModelBuilder.strictJsonSchema(true)
.responseFormat(modelConfig.getJsonFormatType());
}

View File

@@ -338,8 +338,8 @@ class SqlAddHelperTest {
List<String> groupByFields = new ArrayList<>();
groupByFields.add("department");
Set<String> fieldNames = new HashSet<>();
fieldNames.add("pv");
Map<String, String> fieldNames = new HashMap<>();
fieldNames.put("pv", "sum");
String replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
@@ -355,6 +355,14 @@ class SqlAddHelperTest {
Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
replaceSql);
sql = "SELECT 数据日期,访问用户数 FROM 超音数数据集 WHERE 访问次数 > 10 GROUP BY 数据日期";
fieldNames.put("访问次数", "sum");
replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
Assert.assertEquals("SELECT 数据日期, 访问用户数 FROM 超音数数据集 GROUP BY 数据日期 HAVING 访问次数 > 10",
replaceSql);
}
@Test

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
@@ -11,7 +12,8 @@ import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Set;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/** Perform SQL corrections on the "Having" section in S2SQL. */
@@ -29,8 +31,9 @@ public class HavingCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
Map<String, String> metrics = semanticSchema.getMetrics(dataSet).stream()
.collect(Collectors.toMap(SchemaElement::getName,
e -> Optional.ofNullable(e.getDefaultAgg()).orElse("")));
if (CollectionUtils.isEmpty(metrics)) {
return;