diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index d5443c328..1fec3a61e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -594,7 +594,14 @@ public class SqlReplaceHelper { Select selectStatement = SqlSelectHelper.getSelect(sql); List plainSelectList = new ArrayList<>(); if (selectStatement instanceof PlainSelect) { - plainSelectList.add((PlainSelect) selectStatement); + // if with statement exists, replace expression in the with statement. + if (!CollectionUtils.isEmpty(selectStatement.getWithItemsList())) { + selectStatement.getWithItemsList().forEach(withItem -> { + plainSelectList.add(withItem.getSelect().getPlainSelect()); + }); + } else { + plainSelectList.add((PlainSelect) selectStatement); + } } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { @@ -606,9 +613,13 @@ public class SqlReplaceHelper { } else { return sql; } + List plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect plainSelect : plainSelects) { replacePlainSelectByExpr(plainSelect, replace); + if (SqlSelectHelper.hasAggregateFunction(plainSelect)) { + SqlSelectHelper.addMissingGroupby(plainSelect); + } } return selectStatement.toString(); } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index 30812cb5e..772d0294b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -914,4 +914,31 @@ public class SqlSelectHelper { } }); } + + public static void addMissingGroupby(PlainSelect plainSelect) { + if (Objects.nonNull(plainSelect.getGroupBy()) + && !plainSelect.getGroupBy().getGroupByExpressionList().isEmpty()) { + return; + } + GroupByElement groupBy = new GroupByElement(); + for (SelectItem selectItem : plainSelect.getSelectItems()) { + Expression expression = selectItem.getExpression(); + if (expression instanceof Column) { + groupBy.addGroupByExpression(expression); + } + } + if (!groupBy.getGroupByExpressionList().isEmpty()) { + plainSelect.setGroupByElement(groupBy); + } + } + + public static boolean hasAggregateFunction(PlainSelect plainSelect) { + List> selectItems = plainSelect.getSelectItems(); + FunctionVisitor visitor = new FunctionVisitor(); + for (SelectItem selectItem : selectItems) { + selectItem.accept(visitor); + } + return !visitor.getFunctionNames().isEmpty(); + } + } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/OntologyQuery.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/OntologyQuery.java index cc5dd910f..7b8d8c772 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/OntologyQuery.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/OntologyQuery.java @@ -4,7 +4,6 @@ import com.google.common.collect.Sets; import com.tencent.supersonic.common.pojo.ColumnOrder; import com.tencent.supersonic.headless.api.pojo.enums.AggOption; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; -import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; import lombok.Data; @@ -18,16 +17,11 @@ public class OntologyQuery { private Set metrics = Sets.newHashSet(); private Set dimensions = Sets.newHashSet(); private Set fields = Sets.newHashSet(); - private String where; private Long limit; private List order; private boolean nativeQuery = true; private AggOption aggOption = AggOption.NATIVE; - public boolean hasDerivedMetric() { - return metrics.stream().anyMatch(MetricResp::isDerived); - } - public Set getMetricsByModel(Long modelId) { return metrics.stream().filter(m -> m.getModelId().equals(modelId)) .collect(Collectors.toSet()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java index 51049db8d..e780f7aac 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java @@ -44,6 +44,7 @@ public class DefaultSemanticTranslator implements SemanticTranslator { for (QueryOptimizer queryOptimizer : ComponentFactory.getQueryOptimizers()) { queryOptimizer.rewrite(queryStatement); } + log.info("translated query SQL: [{}]", queryStatement.getSql()); } catch (Exception e) { queryStatement.setErrMsg(e.getMessage()); log.error("Failed to translate query [{}]", e.getMessage(), e); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/MetricExpressionConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/MetricExpressionConverter.java index f6950bfd4..3556b018e 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/MetricExpressionConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/MetricExpressionConverter.java @@ -48,9 +48,6 @@ public class MetricExpressionConverter implements QueryConverter { private Map getMetricExpressions(SemanticSchemaResp semanticSchema, OntologyQuery ontologyQuery) { - if (!ontologyQuery.hasDerivedMetric()) { - return Collections.emptyMap(); - } List allMetrics = semanticSchema.getMetrics(); List allDimensions = semanticSchema.getDimensions(); @@ -73,14 +70,14 @@ public class MetricExpressionConverter implements QueryConverter { Map visitedMetrics = new HashMap<>(); Map metric2Expr = new HashMap<>(); for (MetricSchemaResp queryMetric : queryMetrics) { + String fieldExpr = buildFieldExpr(allMetrics, allFields, allMeasures, allDimensions, + queryMetric.getExpr(), queryMetric.getMetricDefineType(), aggOption, + visitedMetrics, queryDimensions, queryFields); + // add all fields referenced in the expression + queryMetric.getFields().addAll(SqlSelectHelper.getFieldsFromExpr(fieldExpr)); + log.debug("derived metric {}->{}", queryMetric.getBizName(), fieldExpr); if (queryMetric.isDerived()) { - String fieldExpr = buildFieldExpr(allMetrics, allFields, allMeasures, allDimensions, - queryMetric.getExpr(), queryMetric.getMetricDefineType(), aggOption, - visitedMetrics, queryDimensions, queryFields); metric2Expr.put(queryMetric.getBizName(), fieldExpr); - // add all fields referenced in the expression - queryMetric.getFields().addAll(SqlSelectHelper.getFieldsFromExpr(fieldExpr)); - log.debug("derived metric {}->{}", queryMetric.getBizName(), fieldExpr); } } @@ -120,13 +117,13 @@ public class MetricExpressionConverter implements QueryConverter { Measure measure = allMeasures.get(field); if (AggOperatorEnum.COUNT_DISTINCT.getOperator() .equalsIgnoreCase(measure.getAgg())) { - return AggOption.NATIVE.equals(aggOption) ? measure.getBizName() + return AggOption.NATIVE.equals(aggOption) ? measure.getExpr() : AggOperatorEnum.COUNT.getOperator() + " ( " - + AggOperatorEnum.DISTINCT + " " - + measure.getBizName() + " ) "; + + AggOperatorEnum.DISTINCT + " " + measure.getExpr() + + " ) "; } - String expr = AggOption.NATIVE.equals(aggOption) ? measure.getBizName() - : measure.getAgg() + " ( " + measure.getBizName() + " ) "; + String expr = AggOption.NATIVE.equals(aggOption) ? measure.getExpr() + : measure.getAgg() + " ( " + measure.getExpr() + " ) "; replace.put(field, expr); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java index e9c8c6019..f31e863af 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java @@ -71,7 +71,6 @@ public class SqlQueryConverter implements QueryConverter { } else if (sqlQueryAggOption.equals(AggOption.NATIVE) && !queryMetrics.isEmpty()) { ontologyQuery.setAggOption(AggOption.DEFAULT); } - ontologyQuery.setNativeQuery(!AggOption.isAgg(ontologyQuery.getAggOption())); queryStatement.setOntologyQuery(ontologyQuery); log.info("parse sqlQuery [{}] ", sqlQuery); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/StructQueryConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/StructQueryConverter.java index 59480d536..45de899bf 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/StructQueryConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/StructQueryConverter.java @@ -35,8 +35,9 @@ public class StructQueryConverter implements QueryConverter { } SqlQuery sqlQuery = new SqlQuery(); sqlQuery.setTable(dsTable); - String sql = String.format("select %s from %s %s %s %s", + String sql = String.format("select %s from %s %s %s %s %s", sqlGenerateUtils.getSelect(structQuery), dsTable, + sqlGenerateUtils.generateWhere(structQuery, null), sqlGenerateUtils.getGroupBy(structQuery), sqlGenerateUtils.getOrderBy(structQuery), sqlGenerateUtils.getLimit(structQuery)); Database database = queryStatement.getOntology().getDatabase(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/SqlBuilder.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/SqlBuilder.java index b37e47f49..e46923e8e 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/SqlBuilder.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/SqlBuilder.java @@ -2,16 +2,14 @@ package com.tencent.supersonic.headless.core.translator.parser.calcite; import com.tencent.supersonic.common.calcite.Configuration; import com.tencent.supersonic.common.pojo.enums.EngineType; -import com.tencent.supersonic.headless.api.pojo.enums.AggOption; import com.tencent.supersonic.headless.core.pojo.DataModel; import com.tencent.supersonic.headless.core.pojo.Database; import com.tencent.supersonic.headless.core.pojo.OntologyQuery; import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.core.translator.parser.calcite.node.DataModelNode; import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode; +import com.tencent.supersonic.headless.core.translator.parser.calcite.render.JoinRender; import com.tencent.supersonic.headless.core.translator.parser.calcite.render.Renderer; -import com.tencent.supersonic.headless.core.translator.parser.calcite.render.SourceRender; -import com.tencent.supersonic.headless.core.translator.parser.s2sql.Constants; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParser; @@ -29,8 +27,6 @@ public class SqlBuilder { private OntologyQuery ontologyQuery; private SqlValidatorScope scope; private SqlNode parserNode; - private boolean isAgg = false; - private AggOption aggOption = AggOption.DEFAULT; public SqlBuilder(S2CalciteSchema schema) { this.schema = schema; @@ -41,7 +37,6 @@ public class SqlBuilder { if (ontologyQuery.getLimit() == null) { ontologyQuery.setLimit(0L); } - this.aggOption = ontologyQuery.getAggOption(); buildParseNode(); Database database = queryStatement.getOntology().getDatabase(); @@ -57,41 +52,25 @@ public class SqlBuilder { if (dataModels == null || dataModels.isEmpty()) { throw new Exception("data model not found"); } - isAgg = getAgg(dataModels.get(0)); - // build level by level LinkedList builders = new LinkedList<>(); - builders.add(new SourceRender()); + builders.add(new JoinRender()); ListIterator it = builders.listIterator(); int i = 0; Renderer previous = null; while (it.hasNext()) { Renderer renderer = it.next(); if (previous != null) { - previous.render(ontologyQuery, dataModels, scope, schema, !isAgg); + previous.render(ontologyQuery, dataModels, scope, schema); renderer.setTable(previous.builderAs(DataModelNode.getNames(dataModels) + "_" + i)); i++; } previous = renderer; } - builders.getLast().render(ontologyQuery, dataModels, scope, schema, !isAgg); + builders.getLast().render(ontologyQuery, dataModels, scope, schema); parserNode = builders.getLast().build(); } - private boolean getAgg(DataModel dataModel) { - if (!AggOption.DEFAULT.equals(aggOption)) { - return AggOption.isAgg(aggOption); - } - // default by dataModel time aggregation - if (Objects.nonNull(dataModel.getAggTime()) && !dataModel.getAggTime() - .equalsIgnoreCase(Constants.DIMENSION_TYPE_TIME_GRANULARITY_NONE)) { - if (!ontologyQuery.isNativeQuery()) { - return true; - } - } - return isAgg; - } - public String getSql(EngineType engineType) { return SemanticNode.getSql(parserNode, engineType); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/TableView.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/TableView.java index ef1522f58..81c8c7373 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/TableView.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/TableView.java @@ -4,13 +4,14 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.headless.core.pojo.DataModel; import lombok.Data; -import org.apache.calcite.sql.*; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.parser.SqlParserPos; import java.util.ArrayList; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** basic query project */ @Data diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/DataModelNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/DataModelNode.java index 7d4f97132..459db41aa 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/DataModelNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/DataModelNode.java @@ -9,7 +9,6 @@ import com.tencent.supersonic.common.pojo.enums.EngineType; import com.tencent.supersonic.headless.api.pojo.Identify; import com.tencent.supersonic.headless.api.pojo.Measure; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; -import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; import com.tencent.supersonic.headless.core.pojo.DataModel; import com.tencent.supersonic.headless.core.pojo.JoinRelation; import com.tencent.supersonic.headless.core.pojo.Ontology; @@ -101,19 +100,6 @@ public class DataModelNode extends SemanticNode { dateInfo, dimensions, metrics); } - public static SqlNode buildExtend(DataModel datasource, Map exprList, - SqlValidatorScope scope) throws Exception { - if (CollectionUtils.isEmpty(exprList)) { - return build(datasource, scope); - } - EngineType engineType = EngineType.fromString(datasource.getType()); - SqlNode dataSet = new SqlBasicCall(new LateralViewExplodeNode(exprList), - Arrays.asList(build(datasource, scope), new SqlNodeList( - getExtendField(exprList, scope, engineType), SqlParserPos.ZERO)), - SqlParserPos.ZERO); - return buildAs(datasource.getName() + Constants.DIMENSION_ARRAY_SINGLE_SUFFIX, dataSet); - } - public static List getExtendField(Map exprList, SqlValidatorScope scope, EngineType engineType) throws Exception { List sqlNodeList = new ArrayList<>(); @@ -153,32 +139,6 @@ public class DataModelNode extends SemanticNode { }); } - public static void mergeQueryFilterDimensionMeasure(Ontology ontology, - OntologyQuery ontologyQuery, Set dimensions, Set measures, - SqlValidatorScope scope) throws Exception { - EngineType engineType = ontology.getDatabase().getType(); - if (Objects.nonNull(ontologyQuery.getWhere()) && !ontologyQuery.getWhere().isEmpty()) { - Set filterConditions = new HashSet<>(); - FilterNode.getFilterField(parse(ontologyQuery.getWhere(), scope, engineType), - filterConditions); - Set queryMeasures = new HashSet<>(measures); - Set schemaMetricName = ontology.getMetrics().stream() - .map(MetricSchemaResp::getName).collect(Collectors.toSet()); - for (String filterCondition : filterConditions) { - if (schemaMetricName.contains(filterCondition)) { - ontology.getMetrics().stream() - .filter(m -> m.getName().equalsIgnoreCase(filterCondition)) - .forEach(m -> m.getMetricDefineByMeasureParams().getMeasures() - .forEach(mm -> queryMeasures.add(mm.getName()))); - continue; - } - dimensions.add(filterCondition); - } - measures.clear(); - measures.addAll(queryMeasures); - } - } - public static List getQueryDataModelsV2(Ontology ontology, OntologyQuery query) { // first, sort models based on the number of query metrics Map modelMetricCount = Maps.newHashMap(); @@ -209,8 +169,8 @@ public class DataModelNode extends SemanticNode { .collect(Collectors.toList()); Set dataModelNames = Sets.newLinkedHashSet(); - dataModelNames.addAll(metricsDataModels); dataModelNames.addAll(dimDataModels); + dataModelNames.addAll(metricsDataModels); return dataModelNames.stream().map(bizName -> ontology.getDataModelMap().get(bizName)) .collect(Collectors.toList()); } @@ -289,45 +249,6 @@ public class DataModelNode extends SemanticNode { return dataModel; } - private static DataModel findBaseModel(Ontology ontology, Set queryMeasures, - Set queryDimensions) { - DataModel dataModel = null; - // first, try to find the model with the most matching measures - Map dataModelMeasuresCount = new HashMap<>(); - for (Map.Entry entry : ontology.getDataModelMap().entrySet()) { - Set sourceMeasure = entry.getValue().getMeasures().stream() - .map(Measure::getName).collect(Collectors.toSet()); - sourceMeasure.retainAll(queryMeasures); - dataModelMeasuresCount.put(entry.getKey(), sourceMeasure.size()); - } - log.info("dataModelMeasureCount: [{}]", dataModelMeasuresCount); - Optional> base = - dataModelMeasuresCount.entrySet().stream().filter(e -> e.getValue() > 0) - .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst(); - - if (base.isPresent()) { - dataModel = ontology.getDataModelMap().get(base.get().getKey()); - } else { - // second, try to find the model with the most matching dimensions - Map dataModelDimCount = new HashMap<>(); - for (Map.Entry> entry : ontology.getDimensionMap() - .entrySet()) { - Set modelDimensions = entry.getValue().stream().map(DimSchemaResp::getName) - .collect(Collectors.toSet()); - modelDimensions.retainAll(queryDimensions); - dataModelDimCount.put(entry.getKey(), modelDimensions.size()); - } - log.info("dataModelDimCount: [{}]", dataModelDimCount); - base = dataModelDimCount.entrySet().stream().filter(e -> e.getValue() > 0) - .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst(); - if (base.isPresent()) { - dataModel = ontology.getDataModelMap().get(base.get().getKey()); - } - } - - return dataModel; - } - private static boolean checkMatch(DataModel baseDataModel, Set queryMeasures, Set queryDimension) { boolean isAllMatch = true; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/JoinRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/JoinRender.java index 03bc222d9..b4e7e3748 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/JoinRender.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/JoinRender.java @@ -9,10 +9,10 @@ import com.tencent.supersonic.headless.core.pojo.JoinRelation; import com.tencent.supersonic.headless.core.pojo.OntologyQuery; import com.tencent.supersonic.headless.core.translator.parser.calcite.S2CalciteSchema; import com.tencent.supersonic.headless.core.translator.parser.calcite.TableView; +import com.tencent.supersonic.headless.core.translator.parser.calcite.node.DataModelNode; import com.tencent.supersonic.headless.core.translator.parser.calcite.node.IdentifyNode; import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode; import com.tencent.supersonic.headless.core.translator.parser.s2sql.Constants; -import com.tencent.supersonic.headless.core.translator.parser.s2sql.Materialization; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.*; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -31,7 +31,7 @@ public class JoinRender extends Renderer { @Override public void render(OntologyQuery ontologyQuery, List dataModels, - SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception { + SqlValidatorScope scope, S2CalciteSchema schema) throws Exception { SqlNode left = null; TableView leftTable = null; Map outerSelect = new HashMap<>(); @@ -50,7 +50,8 @@ public class JoinRender extends Renderer { primary.add(identify.getName()); } - TableView tableView = SourceRender.renderOne(queryMetrics, queryDimensions, dataModel, scope, schema); + TableView tableView = + renderOne(queryMetrics, queryDimensions, dataModel, scope, schema); log.info("tableView {}", StringUtils.normalizeSpace(tableView.getTable().toString())); String alias = Constants.JOIN_TABLE_PREFIX + dataModel.getName(); tableView.setAlias(alias); @@ -60,14 +61,13 @@ public class JoinRender extends Renderer { outerSelect.put(field, SemanticNode.parse(alias + "." + field, scope, engineType)); } if (left == null) { - leftTable = tableView; left = SemanticNode.buildAs(tableView.getAlias(), getTable(tableView)); - beforeModels.put(dataModel.getName(), leftTable.getAlias()); - continue; + } else { + left = buildJoin(left, leftTable, tableView, beforeModels, dataModel, schema, + scope); } - left = buildJoin(left, leftTable, tableView, beforeModels, dataModel, schema, scope); leftTable = tableView; - beforeModels.put(dataModel.getName(), tableView.getAlias()); + beforeModels.put(dataModel.getName(), leftTable.getAlias()); } for (Map.Entry entry : outerSelect.entrySet()) { @@ -84,28 +84,16 @@ public class JoinRender extends Renderer { Map before, DataModel dataModel, S2CalciteSchema schema, SqlValidatorScope scope) throws Exception { EngineType engineType = schema.getOntology().getDatabase().getType(); - SqlNode condition = getCondition(leftTable, rightTable, dataModel, schema, scope, engineType); + SqlNode condition = + getCondition(leftTable, rightTable, dataModel, schema, scope, engineType); SqlLiteral sqlLiteral = SemanticNode.getJoinSqlLiteral(""); JoinRelation matchJoinRelation = getMatchJoinRelation(before, rightTable, schema); - SqlNode joinRelationCondition = null; + SqlNode joinRelationCondition; if (!CollectionUtils.isEmpty(matchJoinRelation.getJoinCondition())) { sqlLiteral = SemanticNode.getJoinSqlLiteral(matchJoinRelation.getJoinType()); joinRelationCondition = getCondition(matchJoinRelation, scope, engineType); condition = joinRelationCondition; } - if (Materialization.TimePartType.ZIPPER.equals(leftTable.getDataModel().getTimePartType()) - || Materialization.TimePartType.ZIPPER - .equals(rightTable.getDataModel().getTimePartType())) { - SqlNode zipperCondition = - getZipperCondition(leftTable, rightTable, dataModel, schema, scope); - if (Objects.nonNull(joinRelationCondition)) { - condition = new SqlBasicCall(SqlStdOperatorTable.AND, - new ArrayList<>(Arrays.asList(zipperCondition, joinRelationCondition)), - SqlParserPos.ZERO, null); - } else { - condition = zipperCondition; - } - } return new SqlJoin(SqlParserPos.ZERO, leftNode, SqlLiteral.createBoolean(false, SqlParserPos.ZERO), sqlLiteral, @@ -172,7 +160,7 @@ public class JoinRender extends Renderer { selectLeft.retainAll(selectRight); SqlNode condition = null; for (String on : selectLeft) { - if (!SourceRender.isDimension(on, dataModel, schema)) { + if (!isDimension(on, dataModel, schema)) { continue; } if (IdentifyNode.isForeign(on, left.getDataModel().getIdentifiers())) { @@ -202,104 +190,47 @@ public class JoinRender extends Renderer { return condition; } - private static void joinOrder(int cnt, String id, Map> next, - Queue orders, Map visited) { - visited.put(id, true); - orders.add(id); - if (orders.size() >= cnt) { - return; + + public static TableView renderOne(Set queryMetrics, + Set queryDimensions, DataModel dataModel, SqlValidatorScope scope, + S2CalciteSchema schema) { + TableView tableView = new TableView(); + EngineType engineType = schema.getOntology().getDatabase().getType(); + Set queryFields = tableView.getFields(); + queryMetrics.stream().forEach(m -> queryFields.addAll(m.getFields())); + queryDimensions.stream().forEach(m -> queryFields.add(m.getBizName())); + + try { + for (String field : queryFields) { + tableView.getSelect().add(SemanticNode.parse(field, scope, engineType)); + } + tableView.setTable(DataModelNode.build(dataModel, scope)); + } catch (Exception e) { + log.error("Failed to create sqlNode for data model {}", dataModel); } - for (String nextId : next.get(id)) { - if (!visited.get(nextId)) { - joinOrder(cnt, nextId, next, orders, visited); - if (orders.size() >= cnt) { - return; - } + + return tableView; + } + + public static boolean isDimension(String name, DataModel dataModel, S2CalciteSchema schema) { + Optional dimension = dataModel.getDimensions().stream() + .filter(d -> d.getName().equalsIgnoreCase(name)).findFirst(); + if (dimension.isPresent()) { + return true; + } + Optional identify = dataModel.getIdentifiers().stream() + .filter(i -> i.getName().equalsIgnoreCase(name)).findFirst(); + if (identify.isPresent()) { + return true; + } + if (schema.getDimensions().containsKey(dataModel.getName())) { + Optional dataSourceDim = schema.getDimensions().get(dataModel.getName()) + .stream().filter(d -> d.getName().equalsIgnoreCase(name)).findFirst(); + if (dataSourceDim.isPresent()) { + return true; } } - orders.poll(); - visited.put(id, false); + return false; } - private void addZipperField(DataModel dataModel, List fields) { - // if (Materialization.TimePartType.ZIPPER.equals(dataModel.getTimePartType())) { - // dataModel.getDimensions().stream() - // .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) - // .forEach(t -> { - // if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_END) - // && !fields.contains(t.getName())) { - // fields.add(t.getName()); - // } - // if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_START) - // && !fields.contains(t.getName())) { - // fields.add(t.getName()); - // } - // }); - // } - } - - private SqlNode getZipperCondition(TableView left, TableView right, DataModel dataModel, - S2CalciteSchema schema, SqlValidatorScope scope) throws Exception { - // if (Materialization.TimePartType.ZIPPER.equals(left.getDataModel().getTimePartType()) - // && Materialization.TimePartType.ZIPPER - // .equals(right.getDataModel().getTimePartType())) { - // throw new Exception("not support two zipper table"); - // } - SqlNode condition = null; - // Optional leftTime = left.getDataModel().getDimensions().stream() - // .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) - // .findFirst(); - // Optional rightTime = right.getDataModel().getDimensions().stream() - // .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) - // .findFirst(); - // if (leftTime.isPresent() && rightTime.isPresent()) { - // - // String startTime = ""; - // String endTime = ""; - // String dateTime = ""; - // - // Optional startTimeOp = (Materialization.TimePartType.ZIPPER - // .equals(left.getDataModel().getTimePartType()) ? left : right).getDataModel() - // .getDimensions().stream() - // .filter(d -> Constants.DIMENSION_TYPE_TIME - // .equalsIgnoreCase(d.getType())) - // .filter(d -> d.getName() - // .startsWith(Constants.MATERIALIZATION_ZIPPER_START)) - // .findFirst(); - // Optional endTimeOp = (Materialization.TimePartType.ZIPPER - // .equals(left.getDataModel().getTimePartType()) ? left : right).getDataModel() - // .getDimensions().stream() - // .filter(d -> Constants.DIMENSION_TYPE_TIME - // .equalsIgnoreCase(d.getType())) - // .filter(d -> d.getName() - // .startsWith(Constants.MATERIALIZATION_ZIPPER_END)) - // .findFirst(); - // if (startTimeOp.isPresent() && endTimeOp.isPresent()) { - // TableView zipper = Materialization.TimePartType.ZIPPER - // .equals(left.getDataModel().getTimePartType()) ? left : right; - // TableView partMetric = Materialization.TimePartType.ZIPPER - // .equals(left.getDataModel().getTimePartType()) ? right : left; - // Optional partTime = Materialization.TimePartType.ZIPPER - // .equals(left.getDataModel().getTimePartType()) ? rightTime : leftTime; - // startTime = zipper.getAlias() + "." + startTimeOp.get().getName(); - // endTime = zipper.getAlias() + "." + endTimeOp.get().getName(); - // dateTime = partMetric.getAlias() + "." + partTime.get().getName(); - // } - // EngineType engineType = schema.getOntology().getDatabase().getType(); - // ArrayList operandList = - // new ArrayList<>(Arrays.asList(SemanticNode.parse(endTime, scope, engineType), - // SemanticNode.parse(dateTime, scope, engineType))); - // condition = new SqlBasicCall(SqlStdOperatorTable.AND, - // new ArrayList(Arrays.asList( - // new SqlBasicCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, - // new ArrayList(Arrays.asList( - // SemanticNode.parse(startTime, scope, engineType), - // SemanticNode.parse(dateTime, scope, engineType))), - // SqlParserPos.ZERO, null), - // new SqlBasicCall(SqlStdOperatorTable.GREATER_THAN, operandList, - // SqlParserPos.ZERO, null))), - // SqlParserPos.ZERO, null); - // } - return condition; - } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/OutputRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/OutputRender.java deleted file mode 100644 index 0bd3dcb3a..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/OutputRender.java +++ /dev/null @@ -1,59 +0,0 @@ -package com.tencent.supersonic.headless.core.translator.parser.calcite.render; - -import com.tencent.supersonic.common.pojo.ColumnOrder; -import com.tencent.supersonic.common.pojo.enums.EngineType; -import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; -import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; -import com.tencent.supersonic.headless.core.pojo.DataModel; -import com.tencent.supersonic.headless.core.pojo.OntologyQuery; -import com.tencent.supersonic.headless.core.translator.parser.calcite.S2CalciteSchema; -import com.tencent.supersonic.headless.core.translator.parser.calcite.node.MetricNode; -import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.validate.SqlValidatorScope; -import org.springframework.util.CollectionUtils; - -import java.util.ArrayList; -import java.util.List; - -/** process the query result items from query request */ -public class OutputRender extends Renderer { - - @Override - public void render(OntologyQuery ontologyQuery, List dataModels, - SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception { - EngineType engineType = schema.getOntology().getDatabase().getType(); - for (DimSchemaResp dimension : ontologyQuery.getDimensions()) { - tableView.getMetric().add(SemanticNode.parse(dimension.getExpr(), scope, engineType)); - } - for (MetricSchemaResp metric : ontologyQuery.getMetrics()) { - if (MetricNode.isMetricField(metric.getName(), schema)) { - // metric from field ignore - continue; - } - tableView.getMetric().add(SemanticNode.parse(metric.getName(), scope, engineType)); - } - - if (ontologyQuery.getLimit() > 0) { - SqlNode offset = - SemanticNode.parse(ontologyQuery.getLimit().toString(), scope, engineType); - tableView.setOffset(offset); - } - if (!CollectionUtils.isEmpty(ontologyQuery.getOrder())) { - List orderList = new ArrayList<>(); - for (ColumnOrder columnOrder : ontologyQuery.getOrder()) { - if (SqlStdOperatorTable.DESC.getName().equalsIgnoreCase(columnOrder.getOrder())) { - orderList.add(SqlStdOperatorTable.DESC.createCall(SqlParserPos.ZERO, - new SqlNode[] {SemanticNode.parse(columnOrder.getCol(), scope, - engineType)})); - } else { - orderList.add(SemanticNode.parse(columnOrder.getCol(), scope, engineType)); - } - } - tableView.setOrder(new SqlNodeList(orderList, SqlParserPos.ZERO)); - } - } -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/Renderer.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/Renderer.java index d7042e487..49390fc60 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/Renderer.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/Renderer.java @@ -30,5 +30,5 @@ public abstract class Renderer { } public abstract void render(OntologyQuery ontologyQuery, List dataModels, - SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception; + SqlValidatorScope scope, S2CalciteSchema schema) throws Exception; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/SourceRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/SourceRender.java deleted file mode 100644 index 59c6e671d..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/render/SourceRender.java +++ /dev/null @@ -1,80 +0,0 @@ -package com.tencent.supersonic.headless.core.translator.parser.calcite.render; - -import com.tencent.supersonic.common.pojo.enums.EngineType; -import com.tencent.supersonic.headless.api.pojo.Identify; -import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; -import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; -import com.tencent.supersonic.headless.core.pojo.DataModel; -import com.tencent.supersonic.headless.core.pojo.OntologyQuery; -import com.tencent.supersonic.headless.core.translator.parser.calcite.S2CalciteSchema; -import com.tencent.supersonic.headless.core.translator.parser.calcite.TableView; -import com.tencent.supersonic.headless.core.translator.parser.calcite.node.DataModelNode; -import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode; -import lombok.extern.slf4j.Slf4j; -import org.apache.calcite.sql.validate.SqlValidatorScope; - -import java.util.List; -import java.util.Optional; -import java.util.Set; - -/** process the table dataSet from the defined data model schema */ -@Slf4j -public class SourceRender extends Renderer { - - public static TableView renderOne(Set queryMetrics, - Set queryDimensions, DataModel dataModel, SqlValidatorScope scope, - S2CalciteSchema schema) { - TableView tableView = new TableView(); - EngineType engineType = schema.getOntology().getDatabase().getType(); - Set queryFields = tableView.getFields(); - queryMetrics.stream().forEach(m -> queryFields.addAll(m.getFields())); - queryDimensions.stream().forEach(m -> queryFields.add(m.getBizName())); - - try { - for (String field : queryFields) { - tableView.getSelect().add(SemanticNode.parse(field, scope, engineType)); - } - tableView.setTable(DataModelNode.build(dataModel, scope)); - } catch (Exception e) { - log.error("Failed to create sqlNode for data model {}", dataModel); - } - - return tableView; - } - - public static boolean isDimension(String name, DataModel dataModel, S2CalciteSchema schema) { - Optional dimension = dataModel.getDimensions().stream() - .filter(d -> d.getName().equalsIgnoreCase(name)).findFirst(); - if (dimension.isPresent()) { - return true; - } - Optional identify = dataModel.getIdentifiers().stream() - .filter(i -> i.getName().equalsIgnoreCase(name)).findFirst(); - if (identify.isPresent()) { - return true; - } - if (schema.getDimensions().containsKey(dataModel.getName())) { - Optional dataSourceDim = schema.getDimensions().get(dataModel.getName()) - .stream().filter(d -> d.getName().equalsIgnoreCase(name)).findFirst(); - if (dataSourceDim.isPresent()) { - return true; - } - } - return false; - } - - - public void render(OntologyQuery ontologyQuery, List dataModels, - SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception { - if (dataModels.size() == 1) { - DataModel dataModel = dataModels.get(0); - tableView = renderOne(ontologyQuery.getMetrics(), ontologyQuery.getDimensions(), - dataModel, scope, schema); - } else { - JoinRender joinRender = new JoinRender(); - joinRender.render(ontologyQuery, dataModels, scope, schema, nonAgg); - tableView = joinRender.getTableView(); - } - } - -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java index b6b41c830..0cda1e84c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java @@ -141,7 +141,12 @@ public class SqlGenerateUtils { String whereClauseFromFilter = sqlFilterUtils.getWhereClause(structQuery.getDimensionFilters()); String whereFromDate = getDateWhereClause(structQuery.getDateInfo(), itemDateResp); - return mergeDateWhereClause(structQuery, whereClauseFromFilter, whereFromDate); + String mergedWhere = + mergeDateWhereClause(structQuery, whereClauseFromFilter, whereFromDate); + if (StringUtils.isNotBlank(mergedWhere)) { + mergedWhere = "where " + mergedWhere; + } + return mergedWhere; } private String mergeDateWhereClause(StructQuery structQuery, String whereClauseFromFilter, diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index 81a129bb2..ff8414e3b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -109,7 +109,7 @@ public class ModelServiceImpl implements ModelService { @Lazy DimensionService dimensionService, @Lazy MetricService metricService, DomainService domainService, UserService userService, DataSetService dataSetService, DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService, - ApplicationEventPublisher eventPublisher) { + ApplicationEventPublisher eventPublisher) { this.modelRepository = modelRepository; this.databaseService = databaseService; this.dimensionService = dimensionService; diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java index a1d9b6d5f..f7b1ec2f7 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java @@ -22,10 +22,10 @@ import com.tencent.supersonic.headless.server.utils.ModelConverter; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.mockito.Mockito; +import org.springframework.context.ApplicationEventPublisher; import java.util.ArrayList; import java.util.List; -import org.springframework.context.ApplicationEventPublisher; import static org.mockito.Mockito.when; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index d2e01520e..fe61d7666 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -83,6 +83,7 @@ public class MetricTest extends BaseTest { } @Test + @SetSystemProperty(key = "s2.test", value = "true") public void testMetricFilter() throws Exception { QueryResult actualResult = submitNewChat("alice的访问次数", agent.getId()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java index 8024bb565..1bbd8a3ca 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java @@ -18,6 +18,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO; import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository; +import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.SchemaService; import com.tencent.supersonic.util.DataUtils; import org.apache.commons.collections.CollectionUtils; @@ -40,6 +41,8 @@ public class BaseTest extends BaseApplication { protected SchemaService schemaService; @Autowired private AgentService agentService; + @Autowired + protected DatabaseService databaseService; protected Agent agent; protected SemanticSchema schema; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java deleted file mode 100644 index c202b4b68..000000000 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java +++ /dev/null @@ -1,53 +0,0 @@ -package com.tencent.supersonic.headless; - -import com.tencent.supersonic.common.pojo.User; -import com.tencent.supersonic.demo.S2VisitsDemo; -import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; -import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; -import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junitpioneer.jupiter.SetSystemProperty; - -import java.util.Collections; -import java.util.Optional; - -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -public class TranslateTest extends BaseTest { - - private Long dataSetId; - - @BeforeEach - public void init() { - agent = getAgentByName(S2VisitsDemo.AGENT_NAME); - schema = schemaService.getSemanticSchema(agent.getDataSetIds()); - Optional id = agent.getDataSetIds().stream().findFirst(); - dataSetId = id.orElse(1L); - } - - @Test - public void testSqlExplain() throws Exception { - String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; - SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); - assertNotNull(explain); - assertNotNull(explain.getQuerySQL()); - assertTrue(explain.getQuerySQL().contains("department")); - assertTrue(explain.getQuerySQL().contains("pv")); - } - - @Test - @SetSystemProperty(key = "s2.test", value = "true") - public void testStructExplain() throws Exception { - QueryStructReq queryStructReq = - buildQueryStructReq(Collections.singletonList("department")); - SemanticTranslateResp explain = - semanticLayerService.translate(queryStructReq, User.getDefaultUser()); - assertNotNull(explain); - assertNotNull(explain.getQuerySQL()); - assertTrue(explain.getQuerySQL().contains("department")); - assertTrue(explain.getQuerySQL().contains("stay_hours")); - } -} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java new file mode 100644 index 000000000..25cae48bc --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java @@ -0,0 +1,94 @@ +package com.tencent.supersonic.headless; + +import com.tencent.supersonic.common.pojo.User; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.demo.S2VisitsDemo; +import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; +import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; +import org.apache.commons.lang3.StringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junitpioneer.jupiter.SetSystemProperty; + +import java.util.Optional; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class TranslatorTest extends BaseTest { + + private Long dataSetId; + + private DatabaseResp databaseResp; + + @BeforeEach + public void init() { + agent = getAgentByName(S2VisitsDemo.AGENT_NAME); + schema = schemaService.getSemanticSchema(agent.getDataSetIds()); + Optional id = agent.getDataSetIds().stream().findFirst(); + dataSetId = id.orElse(1L); + databaseResp = databaseService.getDatabase(1L); + } + + private void executeSql(String sql) { + SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp); + assert StringUtils.isBlank(queryResp.getErrorMsg()); + System.out.println( + String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList()))); + } + + @Test + public void testSql() throws Exception { + String sql = + "SELECT SUM(访问次数) AS _总访问次数_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15'"; + SemanticTranslateResp explain = semanticLayerService + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + assertNotNull(explain); + assertNotNull(explain.getQuerySQL()); + assertTrue(explain.getQuerySQL().contains("count(imp_date)")); + executeSql(explain.getQuerySQL()); + } + + @Test + public void testSql_1() throws Exception { + String sql = "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; + SemanticTranslateResp explain = semanticLayerService + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + assertNotNull(explain); + assertNotNull(explain.getQuerySQL()); + assertTrue(explain.getQuerySQL().contains("department")); + assertTrue(explain.getQuerySQL().contains("count(imp_date)")); + executeSql(explain.getQuerySQL()); + } + + @Test + @SetSystemProperty(key = "s2.test", value = "true") + public void testSql_2() throws Exception { + String sql = + "WITH _department_visits_ AS (SELECT 部门, SUM(访问次数) AS _total_visits_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15' GROUP BY 部门) SELECT 部门 FROM _department_visits_ ORDER BY _total_visits_ DESC LIMIT 2"; + SemanticTranslateResp explain = semanticLayerService + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + assertNotNull(explain); + assertNotNull(explain.getQuerySQL()); + assertTrue(explain.getQuerySQL().toLowerCase().contains("department")); + assertTrue(explain.getQuerySQL().toLowerCase().contains("count(imp_date)")); + executeSql(explain.getQuerySQL()); + } + + @Test + @SetSystemProperty(key = "s2.test", value = "true") + public void testSql_3() throws Exception { + String sql = + "WITH recent_data AS (SELECT 用户名, 访问次数 FROM 超音数数据集 WHERE 部门 = 'marketing' AND 数据日期 >= '2024-12-01' AND 数据日期 <= '2024-12-15') SELECT 用户名 FROM recent_data ORDER BY 访问次数 DESC LIMIT 1"; + SemanticTranslateResp explain = semanticLayerService + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + assertNotNull(explain); + assertNotNull(explain.getQuerySQL()); + assertTrue(explain.getQuerySQL().toLowerCase().contains("department")); + assertTrue(explain.getQuerySQL().toLowerCase().contains("count(imp_date)")); + executeSql(explain.getQuerySQL()); + } + +}