mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
Compare commits
2 Commits
ff76f8edbd
...
5df0b87da9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5df0b87da9 | ||
|
|
ab24b1777a |
@@ -16,16 +16,21 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||
private Set<String> fieldNames;
|
||||
private Map<String, String> fieldNameMap = new HashMap<>();
|
||||
|
||||
private static Set<String> HAVING_AGG_TYPES = Set.of("SUM", "AVG", "MAX", "MIN", "COUNT");
|
||||
|
||||
public FiledFilterReplaceVisitor(Map<String, String> fieldNameMap) {
|
||||
this.fieldNameMap = fieldNameMap;
|
||||
this.fieldNames = fieldNameMap.keySet();
|
||||
}
|
||||
|
||||
public FiledFilterReplaceVisitor(Set<String> fieldNames) {
|
||||
this.fieldNames = fieldNames;
|
||||
@@ -82,7 +87,22 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
return result;
|
||||
if (leftExpression instanceof Column) {
|
||||
Column leftColumn = (Column) leftExpression;
|
||||
String agg = fieldNameMap.get(leftColumn.getColumnName());
|
||||
if (agg != null && HAVING_AGG_TYPES.contains(agg.toUpperCase())) {
|
||||
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||
if (Objects.nonNull(expression)) {
|
||||
result.add(expression);
|
||||
return result;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
Function leftFunction = (Function) leftExpression;
|
||||
@@ -102,14 +122,24 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
return null;
|
||||
}
|
||||
|
||||
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||
if (Objects.nonNull(expression)) {
|
||||
result.add(expression);
|
||||
return result;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Expression parseCondExpression(ComparisonOperator comparisonOperator, String condExpr) {
|
||||
try {
|
||||
String comparisonOperatorStr = comparisonOperator.toString();
|
||||
ComparisonOperator parsedExpression =
|
||||
(ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
|
||||
comparisonOperator.setASTNode(parsedExpression.getASTNode());
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr));
|
||||
return result;
|
||||
return CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
|
||||
@@ -309,7 +309,7 @@ public class SqlAddHelper {
|
||||
}
|
||||
}
|
||||
|
||||
public static String addHaving(String sql, Set<String> fieldNames) {
|
||||
public static String addHaving(String sql, Map<String, String> fieldNames) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
|
||||
if (!(selectStatement instanceof PlainSelect)) {
|
||||
|
||||
@@ -30,7 +30,7 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses());
|
||||
if (modelConfig.getJsonFormat()) {
|
||||
if (modelConfig.getJsonFormat() != null && modelConfig.getJsonFormat()) {
|
||||
openAiChatModelBuilder.strictJsonSchema(true)
|
||||
.responseFormat(modelConfig.getJsonFormatType());
|
||||
}
|
||||
|
||||
@@ -338,8 +338,8 @@ class SqlAddHelperTest {
|
||||
List<String> groupByFields = new ArrayList<>();
|
||||
groupByFields.add("department");
|
||||
|
||||
Set<String> fieldNames = new HashSet<>();
|
||||
fieldNames.add("pv");
|
||||
Map<String, String> fieldNames = new HashMap<>();
|
||||
fieldNames.put("pv", "sum");
|
||||
|
||||
String replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
|
||||
|
||||
@@ -355,6 +355,14 @@ class SqlAddHelperTest {
|
||||
Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "SELECT 数据日期,访问用户数 FROM 超音数数据集 WHERE 访问次数 > 10 GROUP BY 数据日期";
|
||||
|
||||
fieldNames.put("访问次数", "sum");
|
||||
replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
|
||||
|
||||
Assert.assertEquals("SELECT 数据日期, 访问用户数 FROM 超音数数据集 GROUP BY 数据日期 HAVING 访问次数 > 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -11,7 +12,8 @@ import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** Perform SQL corrections on the "Having" section in S2SQL. */
|
||||
@@ -29,8 +31,9 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
Map<String, String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getName,
|
||||
e -> Optional.ofNullable(e.getDefaultAgg()).orElse("")));
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user