diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java index cb86bd5c1..94abc0da5 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java @@ -15,6 +15,7 @@ import lombok.NoArgsConstructor; @AllArgsConstructor @NoArgsConstructor public class SchemaElement implements Serializable { + private Long model; private Long id; private String name; @@ -26,6 +27,8 @@ public class SchemaElement implements Serializable { private List schemaValueMaps; + private String defaultAgg; + @Override public boolean equals(Object o) { if (this == o) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index 63edb2767..270f8126e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -5,13 +5,18 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; @Slf4j public abstract class BaseSemanticCorrector implements SemanticCorrector { @@ -38,4 +43,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return result; } + protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) { + Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql)); + Set whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql)); + + if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) { + return; + } + + whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); + whereFields.removeAll(selectFields); + whereFields.remove(TimeDimensionEnum.DAY.getName()); + whereFields.remove(TimeDimensionEnum.WEEK.getName()); + whereFields.remove(TimeDimensionEnum.MONTH.getName()); + String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); + semanticCorrectInfo.setSql(replaceFields); + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java new file mode 100644 index 000000000..5d73483b6 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java @@ -0,0 +1,29 @@ +package com.tencent.supersonic.chat.corrector; + +import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import java.util.Objects; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Expression; + +@Slf4j +public class GlobalAfterCorrector extends BaseSemanticCorrector { + + @Override + public void correct(SemanticCorrectInfo semanticCorrectInfo) { + + super.correct(semanticCorrectInfo); + String sql = semanticCorrectInfo.getSql(); + if (!SqlParserSelectHelper.hasAggregateFunction(sql)) { + return; + } + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + if (Objects.nonNull(havingExpression)) { + String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + semanticCorrectInfo.setSql(replaceSql); + } + return; + } + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java similarity index 69% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java index fbe6b0a83..36cd1bb71 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java @@ -1,13 +1,16 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.List; import java.util.Map; import java.util.Objects; @@ -17,10 +20,11 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @Slf4j -public class GlobalCorrector extends BaseSemanticCorrector { +public class GlobalBeforeCorrector extends BaseSemanticCorrector { @Override public void correct(SemanticCorrectInfo semanticCorrectInfo) { + super.correct(semanticCorrectInfo); replaceAlias(semanticCorrectInfo); @@ -33,12 +37,26 @@ public class GlobalCorrector extends BaseSemanticCorrector { } private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { + //add aggregate to all metric + String sql = semanticCorrectInfo.getSql(); + Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); - if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) { + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + Map metricToAggregate = semanticSchema.getMetrics(modelId).stream() + .map(schemaElement -> { + if (Objects.isNull(schemaElement.getDefaultAgg())) { + schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name()); + } + return schemaElement; + }).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1)); + + if (CollectionUtils.isEmpty(metricToAggregate)) { return; } + String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate); + semanticCorrectInfo.setSql(aggregateSql); } private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index c931d2f0f..25018fef4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -1,7 +1,17 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; @Slf4j public class GroupByCorrector extends BaseSemanticCorrector { @@ -9,7 +19,25 @@ public class GroupByCorrector extends BaseSemanticCorrector { @Override public void correct(SemanticCorrectInfo semanticCorrectInfo) { + super.correct(semanticCorrectInfo); + //add aggregate to all metric + String sql = semanticCorrectInfo.getSql(); + Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + Set dimensions = semanticSchema.getDimensions(modelId).stream() + .filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName())) + .map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet()); + + List selectFields = SqlParserSelectHelper.getSelectFields(sql); + + if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) { + return; + } + Set groupByFields = selectFields.stream().filter(field -> dimensions.contains(field)) + .collect(Collectors.toSet()); + semanticCorrectInfo.setSql(SqlParserUpdateHelper.addGroupBy(sql, groupByFields)); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java index c5d8a514d..d1dc40d0b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -1,7 +1,14 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; +import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; @Slf4j public class HavingCorrector extends BaseSemanticCorrector { @@ -9,6 +16,22 @@ public class HavingCorrector extends BaseSemanticCorrector { @Override public void correct(SemanticCorrectInfo semanticCorrectInfo) { + super.correct(semanticCorrectInfo); + + //add aggregate to all metric + semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql()); + Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + Set metrics = semanticSchema.getMetrics(modelId).stream() + .map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet()); + + if (CollectionUtils.isEmpty(metrics)) { + return; + } + String havingSql = SqlParserUpdateHelper.addHaving(semanticCorrectInfo.getSql(), metrics); + semanticCorrectInfo.setSql(havingSql); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java index 303a498e5..05e39e902 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java @@ -1,16 +1,7 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.Objects; -import java.util.Set; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.expression.Expression; -import org.springframework.util.CollectionUtils; @Slf4j public class SelectCorrector extends BaseSemanticCorrector { @@ -19,28 +10,6 @@ public class SelectCorrector extends BaseSemanticCorrector { public void correct(SemanticCorrectInfo semanticCorrectInfo) { super.correct(semanticCorrectInfo); String sql = semanticCorrectInfo.getSql(); - - if (SqlParserSelectHelper.hasAggregateFunction(sql)) { - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); - if (Objects.nonNull(havingExpression)) { - String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); - semanticCorrectInfo.setSql(replaceSql); - } - return; - } - Set selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql)); - Set whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql)); - - if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) { - return; - } - - whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); - whereFields.removeAll(selectFields); - whereFields.remove(TimeDimensionEnum.DAY.getName()); - whereFields.remove(TimeDimensionEnum.WEEK.getName()); - whereFields.remove(TimeDimensionEnum.MONTH.getName()); - String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); - semanticCorrectInfo.setSql(replaceFields); + addFieldsToSelect(semanticCorrectInfo, sql); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java index e7d02977f..fd66fe767 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java @@ -128,10 +128,7 @@ public class LLMDslParser implements SemanticParser { public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) { - String correctorSql = semanticCorrectInfo.getPreSql(); - if (StringUtils.isEmpty(correctorSql)) { - correctorSql = semanticCorrectInfo.getSql(); - } + String correctorSql = semanticCorrectInfo.getSql(); parseInfo.getSqlInfo().setLogicSql(correctorSql); List expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); //set dataInfo diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java index f924acbeb..31ea5e864 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java @@ -120,7 +120,7 @@ public class SemanticService { return entityInfo; } catch (Exception e) { - log.error("setMaintModel error {}", e); + log.error("setMainModel error {}", e); } } } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java index 902c3c83a..6cff20149 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java @@ -2,19 +2,19 @@ package com.tencent.supersonic.knowledge.semantic; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; -import com.tencent.supersonic.semantic.api.model.response.ExplainResp; -import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.semantic.api.model.response.ModelResp; -import com.tencent.supersonic.semantic.api.model.response.MetricResp; -import com.tencent.supersonic.semantic.api.model.response.DomainResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; +import com.tencent.supersonic.semantic.api.model.response.DomainResp; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; +import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; +import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; @@ -22,7 +22,6 @@ import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.model.domain.DimensionService; import com.tencent.supersonic.semantic.model.domain.MetricService; -import com.tencent.supersonic.semantic.model.domain.ModelService; import com.tencent.supersonic.semantic.query.service.QueryService; import com.tencent.supersonic.semantic.query.service.SchemaService; import java.util.List; @@ -33,7 +32,6 @@ import lombok.extern.slf4j.Slf4j; public class LocalSemanticLayer extends BaseSemanticLayer { private SchemaService schemaService; - private ModelService modelService; private DimensionService dimensionService; private MetricService metricService; private QueryService queryService; diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java index ff05c84b8..b7ceed424 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java @@ -53,6 +53,7 @@ public class ModelSchemaBuilder { .type(SchemaElementType.METRIC) .useCnt(metric.getUseCnt()) .alias(alias) + .defaultAgg(metric.getDefaultAgg()) .build(); metrics.add(metricToAdd); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java index abe60778f..b3f88d756 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java @@ -31,7 +31,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(MinorThan expr) { - List expressions = parserFilter(expr); + List expressions = parserFilter(expr, " 1 < 2 "); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -39,7 +39,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(EqualsTo expr) { - List expressions = parserFilter(expr); + List expressions = parserFilter(expr, " 1 = 1 "); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -47,7 +47,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(MinorThanEquals expr) { - List expressions = parserFilter(expr); + List expressions = parserFilter(expr, " 1 <= 1 "); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -56,7 +56,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(GreaterThan expr) { - List expressions = parserFilter(expr); + List expressions = parserFilter(expr, " 2 > 1 "); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -64,7 +64,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(GreaterThanEquals expr) { - List expressions = parserFilter(expr); + List expressions = parserFilter(expr, " 1 >= 1 "); if (Objects.nonNull(expressions)) { waitingForAdds.addAll(expressions); } @@ -75,7 +75,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { } - public List parserFilter(ComparisonOperator comparisonOperator) { + public List parserFilter(ComparisonOperator comparisonOperator, String condExpr) { List result = new ArrayList<>(); String toString = comparisonOperator.toString(); Expression leftExpression = comparisonOperator.getLeftExpression(); @@ -97,7 +97,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { return null; } try { - ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 "); + ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr); comparisonOperator.setLeftExpression(expression.getLeftExpression()); comparisonOperator.setRightExpression(expression.getRightExpression()); comparisonOperator.setASTNode(expression.getASTNode()); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index 48906a2d0..01fd38db8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -10,6 +10,9 @@ import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.conditional.OrExpression; +import net.sf.jsqlparser.expression.operators.conditional.XorExpression; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.schema.Column; @@ -281,10 +284,6 @@ public class SqlParserUpdateHelper { } public static String addAggregateToField(String sql, Map fieldNameToAggregate) { - if (SqlParserSelectHelper.hasGroupBy(sql)) { - return sql; - } - Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); @@ -296,13 +295,15 @@ public class SqlParserUpdateHelper { public void visit(PlainSelect plainSelect) { addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate); addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate); + addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate); + addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate); } }); return selectStatement.toString(); } - public static String addGroupBy(String sql, List groupByFields) { - if (SqlParserSelectHelper.hasGroupBy(sql)) { + public static String addGroupBy(String sql, Set groupByFields) { + if (SqlParserSelectHelper.hasGroupBy(sql) || CollectionUtils.isEmpty(groupByFields)) { return sql; } Select selectStatement = SqlParserSelectHelper.getSelect(sql); @@ -327,9 +328,8 @@ public class SqlParserUpdateHelper { if (selectItem instanceof SelectExpressionItem) { SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; Expression expression = selectExpressionItem.getExpression(); - String columnName = ((Column) expression).getColumnName(); - Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); - if (Objects.isNull(function)) { + Function function = getFunction(expression, fieldNameToAggregate); + if (function == null) { continue; } selectExpressionItem.setExpression(function); @@ -344,18 +344,102 @@ public class SqlParserUpdateHelper { } for (OrderByElement orderByElement : orderByElements) { Expression expression = orderByElement.getExpression(); - String columnName = ((Column) expression).getColumnName(); - if (StringUtils.isEmpty(columnName)) { - continue; - } - Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); - if (Objects.isNull(function)) { + Function function = getFunction(expression, fieldNameToAggregate); + if (function == null) { continue; } orderByElement.setExpression(function); } } + private static void addAggregateToGroupByItems(GroupByElement groupByElement, + Map fieldNameToAggregate) { + if (groupByElement == null) { + return; + } + for (Expression expression : groupByElement.getGroupByExpressions()) { + Function function = getFunction(expression, fieldNameToAggregate); + if (function == null) { + continue; + } + groupByElement.addGroupByExpression(function); + } + } + + private static void addAggregateToWhereItems(Expression whereExpression, Map fieldNameToAggregate) { + if (whereExpression == null) { + return; + } + modifyWhereExpression(whereExpression, fieldNameToAggregate); + } + + private static void modifyWhereExpression(Expression whereExpression, + Map fieldNameToAggregate) { + if (isLogicExpression(whereExpression)) { + AndExpression andExpression = (AndExpression) whereExpression; + Expression leftExpression = andExpression.getLeftExpression(); + Expression rightExpression = andExpression.getRightExpression(); + if (isLogicExpression(leftExpression)) { + modifyWhereExpression(leftExpression, fieldNameToAggregate); + } else { + setAggToFunction(leftExpression, fieldNameToAggregate); + } + if (isLogicExpression(rightExpression)) { + modifyWhereExpression(rightExpression, fieldNameToAggregate); + } else { + setAggToFunction(rightExpression, fieldNameToAggregate); + } + setAggToFunction(rightExpression, fieldNameToAggregate); + } else { + setAggToFunction(whereExpression, fieldNameToAggregate); + } + } + + private static boolean isLogicExpression(Expression whereExpression) { + return whereExpression instanceof AndExpression || (whereExpression instanceof OrExpression + || (whereExpression instanceof XorExpression)); + } + + + private static void setAggToFunction(Expression expression, Map fieldNameToAggregate) { + if (!(expression instanceof ComparisonOperator)) { + return; + } + ComparisonOperator comparisonOperator = (ComparisonOperator) expression; + if (comparisonOperator.getRightExpression() instanceof Column) { + String columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName(); + Function function = getFunction(comparisonOperator.getRightExpression(), + fieldNameToAggregate.get(columnName)); + if (Objects.nonNull(function)) { + comparisonOperator.setRightExpression(function); + } + } + if (comparisonOperator.getLeftExpression() instanceof Column) { + String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName(); + Function function = getFunction(comparisonOperator.getLeftExpression(), + fieldNameToAggregate.get(columnName)); + if (Objects.nonNull(function)) { + comparisonOperator.setLeftExpression(function); + } + } + } + + + private static Function getFunction(Expression expression, Map fieldNameToAggregate) { + if (!(expression instanceof Column)) { + return null; + } + String columnName = ((Column) expression).getColumnName(); + if (StringUtils.isEmpty(columnName)) { + return null; + } + Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); + if (Objects.isNull(function)) { + return null; + } + return function; + } + private static Function getFunction(Expression expression, String aggregateName) { if (StringUtils.isEmpty(aggregateName)) { return null; diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index 8745235e8..9164176df 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -301,7 +301,7 @@ class SqlParserUpdateHelperTest { Map filedNameToAggregate = new HashMap<>(); filedNameToAggregate.put("pv", "sum"); - List groupByFields = new ArrayList<>(); + Set groupByFields = new HashSet<>(); groupByFields.add("department"); String replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); @@ -311,6 +311,66 @@ class SqlParserUpdateHelperTest { "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); + + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " + + "order by pv desc limit 10"; + replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where pv >1 order by pv desc limit 10"; + replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where sum(pv) >1 order by pv desc limit 10"; + replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + + sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 " + + "GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " + + "GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 and department = 'HR' " + + "GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 " + + "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); } @Test @@ -318,7 +378,7 @@ class SqlParserUpdateHelperTest { String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " + "order by sum(pv) desc limit 10"; - List groupByFields = new ArrayList<>(); + Set groupByFields = new HashSet<>(); groupByFields.add("department"); String replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields); @@ -342,8 +402,8 @@ class SqlParserUpdateHelperTest { String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames); Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "AND 1 > 1 GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 " + + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", replaceSql); } diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index 4835d1225..98f9ca666 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -31,9 +31,10 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ - com.tencent.supersonic.chat.corrector.GlobalCorrector, \ - com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \ com.tencent.supersonic.chat.corrector.SelectCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector, \ - com.tencent.supersonic.chat.corrector.TableCorrector \ No newline at end of file + com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.TableCorrector, \ + com.tencent.supersonic.chat.corrector.GlobalAfterCorrector \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index ed43f0e5a..cb441f5bd 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -31,9 +31,10 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor com.tencent.supersonic.chat.api.component.SemanticCorrector=\ - com.tencent.supersonic.chat.corrector.GlobalCorrector, \ - com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \ com.tencent.supersonic.chat.corrector.SelectCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector, \ - com.tencent.supersonic.chat.corrector.TableCorrector \ No newline at end of file + com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.TableCorrector, \ + com.tencent.supersonic.chat.corrector.GlobalAfterCorrector \ No newline at end of file diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java index 0a2aecd24..4307006ad 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java @@ -37,6 +37,8 @@ public class MetricResp extends SchemaItem { private boolean hasAdminRes = false; + private String defaultAgg; + public void setTag(String tag) { if (StringUtils.isBlank(tag)) { tags = Lists.newArrayList(); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java index aedaf2250..847793046 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java @@ -18,6 +18,7 @@ import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp; import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; +import com.tencent.supersonic.semantic.model.domain.Catalog; import com.tencent.supersonic.semantic.model.domain.DatabaseService; import com.tencent.supersonic.semantic.model.domain.DatasourceService; import com.tencent.supersonic.semantic.model.domain.DimensionService; @@ -54,10 +55,13 @@ public class ModelServiceImpl implements ModelService { private final UserService userService; private final DatabaseService databaseService; + private final Catalog catalog; + public ModelServiceImpl(ModelRepository modelRepository, @Lazy MetricService metricService, - @Lazy DimensionService dimensionService, @Lazy DatasourceService datasourceService, - @Lazy DomainService domainService, UserService userService, - @Lazy DatabaseService databaseService) { + @Lazy DimensionService dimensionService, @Lazy DatasourceService datasourceService, + @Lazy DomainService domainService, UserService userService, + @Lazy DatabaseService databaseService, + @Lazy Catalog catalog) { this.modelRepository = modelRepository; this.metricService = metricService; this.dimensionService = dimensionService; @@ -65,6 +69,7 @@ public class ModelServiceImpl implements ModelService { this.domainService = domainService; this.userService = userService; this.databaseService = databaseService; + this.catalog = catalog; } @Override @@ -166,7 +171,7 @@ public class ModelServiceImpl implements ModelService { @Override public ModelResp getModel(Long id) { Map domainRespMap = domainService.getDomainList().stream() - .collect(Collectors.toMap(DomainResp::getId, d -> d)); + .collect(Collectors.toMap(DomainResp::getId, d -> d)); return ModelConvert.convert(getModelDO(id), domainRespMap); } @@ -192,7 +197,7 @@ public class ModelServiceImpl implements ModelService { return modelResps; } Map domainRespMap = domainService.getDomainList() - .stream().collect(Collectors.toMap(DomainResp::getId, d -> d)); + .stream().collect(Collectors.toMap(DomainResp::getId, d -> d)); return modelDOS.stream() .map(modelDO -> ModelConvert.convert(modelDO, domainRespMap)) .collect(Collectors.toList()); @@ -298,6 +303,8 @@ public class ModelServiceImpl implements ModelService { MetricSchemaResp metricSchemaDesc = new MetricSchemaResp(); BeanUtils.copyProperties(metricDesc, metricSchemaDesc); metricSchemaDesc.setUseCnt(0L); + String agg = catalog.getAgg(modelId, metricSchemaDesc.getBizName()); + metricSchemaDesc.setDefaultAgg(agg); metricSchemaDescList.add(metricSchemaDesc); } );