mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) add default aggregate to all metric and add group by to dimension and add metric filter in having (#150)
This commit is contained in:
@@ -15,6 +15,7 @@ import lombok.NoArgsConstructor;
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class SchemaElement implements Serializable {
|
||||
|
||||
private Long model;
|
||||
private Long id;
|
||||
private String name;
|
||||
@@ -26,6 +27,8 @@ public class SchemaElement implements Serializable {
|
||||
|
||||
private List<SchemaValueMap> schemaValueMaps;
|
||||
|
||||
private String defaultAgg;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
@@ -5,13 +5,18 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
@@ -38,4 +43,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return result;
|
||||
}
|
||||
|
||||
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
semanticCorrectInfo.setSql(replaceFields);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
if (!SqlParserSelectHelper.hasAggregateFunction(sql)) {
|
||||
return;
|
||||
}
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
semanticCorrectInfo.setSql(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,13 +1,16 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -17,10 +20,11 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalCorrector extends BaseSemanticCorrector {
|
||||
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
|
||||
replaceAlias(semanticCorrectInfo);
|
||||
@@ -33,12 +37,26 @@ public class GlobalCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
//add aggregate to all metric
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
|
||||
if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Map<String, String> metricToAggregate = semanticSchema.getMetrics(modelId).stream()
|
||||
.map(schemaElement -> {
|
||||
if (Objects.isNull(schemaElement.getDefaultAgg())) {
|
||||
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||
}
|
||||
return schemaElement;
|
||||
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
|
||||
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
|
||||
String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate);
|
||||
semanticCorrectInfo.setSql(aggregateSql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
@@ -1,7 +1,17 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
@@ -9,7 +19,25 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
|
||||
//add aggregate to all metric
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
|
||||
.filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName()))
|
||||
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
Set<String> groupByFields = selectFields.stream().filter(field -> dimensions.contains(field))
|
||||
.collect(Collectors.toSet());
|
||||
semanticCorrectInfo.setSql(SqlParserUpdateHelper.addGroupBy(sql, groupByFields));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class HavingCorrector extends BaseSemanticCorrector {
|
||||
@@ -9,6 +16,22 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
|
||||
//add aggregate to all metric
|
||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
||||
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlParserUpdateHelper.addHaving(semanticCorrectInfo.getSql(), metrics);
|
||||
semanticCorrectInfo.setSql(havingSql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,16 +1,7 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class SelectCorrector extends BaseSemanticCorrector {
|
||||
@@ -19,28 +10,6 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
super.correct(semanticCorrectInfo);
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
|
||||
if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
semanticCorrectInfo.setSql(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
semanticCorrectInfo.setSql(replaceFields);
|
||||
addFieldsToSelect(semanticCorrectInfo, sql);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,10 +128,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
|
||||
public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
|
||||
|
||||
String correctorSql = semanticCorrectInfo.getPreSql();
|
||||
if (StringUtils.isEmpty(correctorSql)) {
|
||||
correctorSql = semanticCorrectInfo.getSql();
|
||||
}
|
||||
String correctorSql = semanticCorrectInfo.getSql();
|
||||
parseInfo.getSqlInfo().setLogicSql(correctorSql);
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
||||
//set dataInfo
|
||||
|
||||
@@ -120,7 +120,7 @@ public class SemanticService {
|
||||
|
||||
return entityInfo;
|
||||
} catch (Exception e) {
|
||||
log.error("setMaintModel error {}", e);
|
||||
log.error("setMainModel error {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,19 +2,19 @@ package com.tencent.supersonic.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
|
||||
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
|
||||
@@ -22,7 +22,6 @@ import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.query.service.QueryService;
|
||||
import com.tencent.supersonic.semantic.query.service.SchemaService;
|
||||
import java.util.List;
|
||||
@@ -33,7 +32,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||
public class LocalSemanticLayer extends BaseSemanticLayer {
|
||||
|
||||
private SchemaService schemaService;
|
||||
private ModelService modelService;
|
||||
private DimensionService dimensionService;
|
||||
private MetricService metricService;
|
||||
private QueryService queryService;
|
||||
|
||||
@@ -53,6 +53,7 @@ public class ModelSchemaBuilder {
|
||||
.type(SchemaElementType.METRIC)
|
||||
.useCnt(metric.getUseCnt())
|
||||
.alias(alias)
|
||||
.defaultAgg(metric.getDefaultAgg())
|
||||
.build();
|
||||
metrics.add(metricToAdd);
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
@Override
|
||||
public void visit(MinorThan expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
List<Expression> expressions = parserFilter(expr, " 1 < 2 ");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
@@ -39,7 +39,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
@Override
|
||||
public void visit(EqualsTo expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
List<Expression> expressions = parserFilter(expr, " 1 = 1 ");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
@@ -47,7 +47,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
@Override
|
||||
public void visit(MinorThanEquals expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
List<Expression> expressions = parserFilter(expr, " 1 <= 1 ");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
@@ -56,7 +56,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThan expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
List<Expression> expressions = parserFilter(expr, " 2 > 1 ");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
@@ -64,7 +64,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThanEquals expr) {
|
||||
List<Expression> expressions = parserFilter(expr);
|
||||
List<Expression> expressions = parserFilter(expr, " 1 >= 1 ");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
@@ -75,7 +75,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
}
|
||||
|
||||
|
||||
public List<Expression> parserFilter(ComparisonOperator comparisonOperator) {
|
||||
public List<Expression> parserFilter(ComparisonOperator comparisonOperator, String condExpr) {
|
||||
List<Expression> result = new ArrayList<>();
|
||||
String toString = comparisonOperator.toString();
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
@@ -97,7 +97,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 ");
|
||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
||||
comparisonOperator.setASTNode(expression.getASTNode());
|
||||
|
||||
@@ -10,6 +10,9 @@ import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.XorExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
@@ -281,10 +284,6 @@ public class SqlParserUpdateHelper {
|
||||
}
|
||||
|
||||
public static String addAggregateToField(String sql, Map<String, String> fieldNameToAggregate) {
|
||||
if (SqlParserSelectHelper.hasGroupBy(sql)) {
|
||||
return sql;
|
||||
}
|
||||
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
@@ -296,13 +295,15 @@ public class SqlParserUpdateHelper {
|
||||
public void visit(PlainSelect plainSelect) {
|
||||
addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate);
|
||||
addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate);
|
||||
addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate);
|
||||
addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate);
|
||||
}
|
||||
});
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String addGroupBy(String sql, List<String> groupByFields) {
|
||||
if (SqlParserSelectHelper.hasGroupBy(sql)) {
|
||||
public static String addGroupBy(String sql, Set<String> groupByFields) {
|
||||
if (SqlParserSelectHelper.hasGroupBy(sql) || CollectionUtils.isEmpty(groupByFields)) {
|
||||
return sql;
|
||||
}
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
@@ -327,9 +328,8 @@ public class SqlParserUpdateHelper {
|
||||
if (selectItem instanceof SelectExpressionItem) {
|
||||
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
|
||||
Expression expression = selectExpressionItem.getExpression();
|
||||
String columnName = ((Column) expression).getColumnName();
|
||||
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
|
||||
if (Objects.isNull(function)) {
|
||||
Function function = getFunction(expression, fieldNameToAggregate);
|
||||
if (function == null) {
|
||||
continue;
|
||||
}
|
||||
selectExpressionItem.setExpression(function);
|
||||
@@ -344,18 +344,102 @@ public class SqlParserUpdateHelper {
|
||||
}
|
||||
for (OrderByElement orderByElement : orderByElements) {
|
||||
Expression expression = orderByElement.getExpression();
|
||||
String columnName = ((Column) expression).getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
continue;
|
||||
}
|
||||
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
|
||||
if (Objects.isNull(function)) {
|
||||
Function function = getFunction(expression, fieldNameToAggregate);
|
||||
if (function == null) {
|
||||
continue;
|
||||
}
|
||||
orderByElement.setExpression(function);
|
||||
}
|
||||
}
|
||||
|
||||
private static void addAggregateToGroupByItems(GroupByElement groupByElement,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
if (groupByElement == null) {
|
||||
return;
|
||||
}
|
||||
for (Expression expression : groupByElement.getGroupByExpressions()) {
|
||||
Function function = getFunction(expression, fieldNameToAggregate);
|
||||
if (function == null) {
|
||||
continue;
|
||||
}
|
||||
groupByElement.addGroupByExpression(function);
|
||||
}
|
||||
}
|
||||
|
||||
private static void addAggregateToWhereItems(Expression whereExpression, Map<String, String> fieldNameToAggregate) {
|
||||
if (whereExpression == null) {
|
||||
return;
|
||||
}
|
||||
modifyWhereExpression(whereExpression, fieldNameToAggregate);
|
||||
}
|
||||
|
||||
private static void modifyWhereExpression(Expression whereExpression,
|
||||
Map<String, String> fieldNameToAggregate) {
|
||||
if (isLogicExpression(whereExpression)) {
|
||||
AndExpression andExpression = (AndExpression) whereExpression;
|
||||
Expression leftExpression = andExpression.getLeftExpression();
|
||||
Expression rightExpression = andExpression.getRightExpression();
|
||||
if (isLogicExpression(leftExpression)) {
|
||||
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
||||
} else {
|
||||
setAggToFunction(leftExpression, fieldNameToAggregate);
|
||||
}
|
||||
if (isLogicExpression(rightExpression)) {
|
||||
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
||||
} else {
|
||||
setAggToFunction(rightExpression, fieldNameToAggregate);
|
||||
}
|
||||
setAggToFunction(rightExpression, fieldNameToAggregate);
|
||||
} else {
|
||||
setAggToFunction(whereExpression, fieldNameToAggregate);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isLogicExpression(Expression whereExpression) {
|
||||
return whereExpression instanceof AndExpression || (whereExpression instanceof OrExpression
|
||||
|| (whereExpression instanceof XorExpression));
|
||||
}
|
||||
|
||||
|
||||
private static void setAggToFunction(Expression expression, Map<String, String> fieldNameToAggregate) {
|
||||
if (!(expression instanceof ComparisonOperator)) {
|
||||
return;
|
||||
}
|
||||
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
|
||||
if (comparisonOperator.getRightExpression() instanceof Column) {
|
||||
String columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName();
|
||||
Function function = getFunction(comparisonOperator.getRightExpression(),
|
||||
fieldNameToAggregate.get(columnName));
|
||||
if (Objects.nonNull(function)) {
|
||||
comparisonOperator.setRightExpression(function);
|
||||
}
|
||||
}
|
||||
if (comparisonOperator.getLeftExpression() instanceof Column) {
|
||||
String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
|
||||
Function function = getFunction(comparisonOperator.getLeftExpression(),
|
||||
fieldNameToAggregate.get(columnName));
|
||||
if (Objects.nonNull(function)) {
|
||||
comparisonOperator.setLeftExpression(function);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static Function getFunction(Expression expression, Map<String, String> fieldNameToAggregate) {
|
||||
if (!(expression instanceof Column)) {
|
||||
return null;
|
||||
}
|
||||
String columnName = ((Column) expression).getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
return null;
|
||||
}
|
||||
Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
|
||||
if (Objects.isNull(function)) {
|
||||
return null;
|
||||
}
|
||||
return function;
|
||||
}
|
||||
|
||||
private static Function getFunction(Expression expression, String aggregateName) {
|
||||
if (StringUtils.isEmpty(aggregateName)) {
|
||||
return null;
|
||||
|
||||
@@ -301,7 +301,7 @@ class SqlParserUpdateHelperTest {
|
||||
Map<String, String> filedNameToAggregate = new HashMap<>();
|
||||
filedNameToAggregate.put("pv", "sum");
|
||||
|
||||
List<String> groupByFields = new ArrayList<>();
|
||||
Set<String> groupByFields = new HashSet<>();
|
||||
groupByFields.add("department");
|
||||
|
||||
String replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
@@ -311,6 +311,66 @@ class SqlParserUpdateHelperTest {
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 "
|
||||
+ "order by pv desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select department, pv from t_1 where pv >1 order by pv desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 "
|
||||
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select department, pv from t_1 where sum(pv) >1 order by pv desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 "
|
||||
+ "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
|
||||
sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 "
|
||||
+ "GROUP BY department order by pv desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
|
||||
sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 "
|
||||
+ "GROUP BY department order by pv desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 and department = 'HR' "
|
||||
+ "GROUP BY department order by pv desc limit 10";
|
||||
replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||
replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 "
|
||||
+ "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -318,7 +378,7 @@ class SqlParserUpdateHelperTest {
|
||||
String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' "
|
||||
+ "order by sum(pv) desc limit 10";
|
||||
|
||||
List<String> groupByFields = new ArrayList<>();
|
||||
Set<String> groupByFields = new HashSet<>();
|
||||
groupByFields.add("department");
|
||||
|
||||
String replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields);
|
||||
@@ -342,8 +402,8 @@ class SqlParserUpdateHelperTest {
|
||||
String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "AND 1 > 1 GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||
"SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 "
|
||||
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
|
||||
@@ -31,9 +31,10 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
||||
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.TableCorrector
|
||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.TableCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
|
||||
@@ -31,9 +31,10 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.TableCorrector
|
||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.TableCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
|
||||
@@ -37,6 +37,8 @@ public class MetricResp extends SchemaItem {
|
||||
|
||||
private boolean hasAdminRes = false;
|
||||
|
||||
private String defaultAgg;
|
||||
|
||||
public void setTag(String tag) {
|
||||
if (StringUtils.isBlank(tag)) {
|
||||
tags = Lists.newArrayList();
|
||||
|
||||
@@ -18,6 +18,7 @@ import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.model.domain.Catalog;
|
||||
import com.tencent.supersonic.semantic.model.domain.DatabaseService;
|
||||
import com.tencent.supersonic.semantic.model.domain.DatasourceService;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
@@ -54,10 +55,13 @@ public class ModelServiceImpl implements ModelService {
|
||||
private final UserService userService;
|
||||
private final DatabaseService databaseService;
|
||||
|
||||
private final Catalog catalog;
|
||||
|
||||
public ModelServiceImpl(ModelRepository modelRepository, @Lazy MetricService metricService,
|
||||
@Lazy DimensionService dimensionService, @Lazy DatasourceService datasourceService,
|
||||
@Lazy DomainService domainService, UserService userService,
|
||||
@Lazy DatabaseService databaseService) {
|
||||
@Lazy DimensionService dimensionService, @Lazy DatasourceService datasourceService,
|
||||
@Lazy DomainService domainService, UserService userService,
|
||||
@Lazy DatabaseService databaseService,
|
||||
@Lazy Catalog catalog) {
|
||||
this.modelRepository = modelRepository;
|
||||
this.metricService = metricService;
|
||||
this.dimensionService = dimensionService;
|
||||
@@ -65,6 +69,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
this.domainService = domainService;
|
||||
this.userService = userService;
|
||||
this.databaseService = databaseService;
|
||||
this.catalog = catalog;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -166,7 +171,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
@Override
|
||||
public ModelResp getModel(Long id) {
|
||||
Map<Long, DomainResp> domainRespMap = domainService.getDomainList().stream()
|
||||
.collect(Collectors.toMap(DomainResp::getId, d -> d));
|
||||
.collect(Collectors.toMap(DomainResp::getId, d -> d));
|
||||
return ModelConvert.convert(getModelDO(id), domainRespMap);
|
||||
}
|
||||
|
||||
@@ -192,7 +197,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
return modelResps;
|
||||
}
|
||||
Map<Long, DomainResp> domainRespMap = domainService.getDomainList()
|
||||
.stream().collect(Collectors.toMap(DomainResp::getId, d -> d));
|
||||
.stream().collect(Collectors.toMap(DomainResp::getId, d -> d));
|
||||
return modelDOS.stream()
|
||||
.map(modelDO -> ModelConvert.convert(modelDO, domainRespMap))
|
||||
.collect(Collectors.toList());
|
||||
@@ -298,6 +303,8 @@ public class ModelServiceImpl implements ModelService {
|
||||
MetricSchemaResp metricSchemaDesc = new MetricSchemaResp();
|
||||
BeanUtils.copyProperties(metricDesc, metricSchemaDesc);
|
||||
metricSchemaDesc.setUseCnt(0L);
|
||||
String agg = catalog.getAgg(modelId, metricSchemaDesc.getBizName());
|
||||
metricSchemaDesc.setDefaultAgg(agg);
|
||||
metricSchemaDescList.add(metricSchemaDesc);
|
||||
}
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user