[improvement][headless] Enhance translation of derived metrics and refactor translator code.

This commit is contained in:
jerryjzhang
2024-12-15 11:37:16 +08:00
parent 14087825df
commit ed5c129a4a
21 changed files with 218 additions and 445 deletions

View File

@@ -594,7 +594,14 @@ public class SqlReplaceHelper {
Select selectStatement = SqlSelectHelper.getSelect(sql);
List<PlainSelect> plainSelectList = new ArrayList<>();
if (selectStatement instanceof PlainSelect) {
plainSelectList.add((PlainSelect) selectStatement);
// if with statement exists, replace expression in the with statement.
if (!CollectionUtils.isEmpty(selectStatement.getWithItemsList())) {
selectStatement.getWithItemsList().forEach(withItem -> {
plainSelectList.add(withItem.getSelect().getPlainSelect());
});
} else {
plainSelectList.add((PlainSelect) selectStatement);
}
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
@@ -606,9 +613,13 @@ public class SqlReplaceHelper {
} else {
return sql;
}
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList);
for (PlainSelect plainSelect : plainSelects) {
replacePlainSelectByExpr(plainSelect, replace);
if (SqlSelectHelper.hasAggregateFunction(plainSelect)) {
SqlSelectHelper.addMissingGroupby(plainSelect);
}
}
return selectStatement.toString();
}

View File

@@ -914,4 +914,31 @@ public class SqlSelectHelper {
}
});
}
/**
 * Adds a GROUP BY clause listing every plain column in the select list, but only when the
 * query does not already declare a non-empty GROUP BY. Non-column select items (function
 * calls, expressions) are skipped.
 *
 * @param plainSelect the select statement to augment in place
 */
public static void addMissingGroupby(PlainSelect plainSelect) {
    // Respect any group-by the query already carries.
    GroupByElement existing = plainSelect.getGroupBy();
    if (Objects.nonNull(existing) && !existing.getGroupByExpressionList().isEmpty()) {
        return;
    }
    GroupByElement generated = new GroupByElement();
    plainSelect.getSelectItems().stream()
            .map(SelectItem::getExpression)
            .filter(expr -> expr instanceof Column)
            .forEach(generated::addGroupByExpression);
    // Only attach a group-by when at least one plain column was found.
    if (!generated.getGroupByExpressionList().isEmpty()) {
        plainSelect.setGroupByElement(generated);
    }
}
/**
 * Returns true if any select item of the given statement contains a function call, as
 * collected by {@code FunctionVisitor}. Note this inspects only the select list, not
 * WHERE/HAVING clauses.
 *
 * @param plainSelect the select statement to inspect
 * @return true when at least one function name was collected from the select items
 */
public static boolean hasAggregateFunction(PlainSelect plainSelect) {
    List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
    FunctionVisitor visitor = new FunctionVisitor();
    // Use the parameterized type to avoid a raw-type unchecked warning.
    for (SelectItem<?> selectItem : selectItems) {
        selectItem.accept(visitor);
    }
    return !visitor.getFunctionNames().isEmpty();
}
}

View File

@@ -4,7 +4,6 @@ import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ColumnOrder;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import lombok.Data;
@@ -18,16 +17,11 @@ public class OntologyQuery {
private Set<MetricSchemaResp> metrics = Sets.newHashSet();
private Set<DimSchemaResp> dimensions = Sets.newHashSet();
private Set<String> fields = Sets.newHashSet();
private String where;
private Long limit;
private List<ColumnOrder> order;
private boolean nativeQuery = true;
private AggOption aggOption = AggOption.NATIVE;
/** Returns true when at least one queried metric is a derived metric. */
public boolean hasDerivedMetric() {
    for (MetricSchemaResp metric : metrics) {
        if (metric.isDerived()) {
            return true;
        }
    }
    return false;
}
public Set<MetricSchemaResp> getMetricsByModel(Long modelId) {
return metrics.stream().filter(m -> m.getModelId().equals(modelId))
.collect(Collectors.toSet());

View File

@@ -44,6 +44,7 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
for (QueryOptimizer queryOptimizer : ComponentFactory.getQueryOptimizers()) {
queryOptimizer.rewrite(queryStatement);
}
log.info("translated query SQL: [{}]", queryStatement.getSql());
} catch (Exception e) {
queryStatement.setErrMsg(e.getMessage());
log.error("Failed to translate query [{}]", e.getMessage(), e);

View File

@@ -48,9 +48,6 @@ public class MetricExpressionConverter implements QueryConverter {
private Map<String, String> getMetricExpressions(SemanticSchemaResp semanticSchema,
OntologyQuery ontologyQuery) {
if (!ontologyQuery.hasDerivedMetric()) {
return Collections.emptyMap();
}
List<MetricSchemaResp> allMetrics = semanticSchema.getMetrics();
List<DimSchemaResp> allDimensions = semanticSchema.getDimensions();
@@ -73,14 +70,14 @@ public class MetricExpressionConverter implements QueryConverter {
Map<String, String> visitedMetrics = new HashMap<>();
Map<String, String> metric2Expr = new HashMap<>();
for (MetricSchemaResp queryMetric : queryMetrics) {
String fieldExpr = buildFieldExpr(allMetrics, allFields, allMeasures, allDimensions,
queryMetric.getExpr(), queryMetric.getMetricDefineType(), aggOption,
visitedMetrics, queryDimensions, queryFields);
// add all fields referenced in the expression
queryMetric.getFields().addAll(SqlSelectHelper.getFieldsFromExpr(fieldExpr));
log.debug("derived metric {}->{}", queryMetric.getBizName(), fieldExpr);
if (queryMetric.isDerived()) {
String fieldExpr = buildFieldExpr(allMetrics, allFields, allMeasures, allDimensions,
queryMetric.getExpr(), queryMetric.getMetricDefineType(), aggOption,
visitedMetrics, queryDimensions, queryFields);
metric2Expr.put(queryMetric.getBizName(), fieldExpr);
// add all fields referenced in the expression
queryMetric.getFields().addAll(SqlSelectHelper.getFieldsFromExpr(fieldExpr));
log.debug("derived metric {}->{}", queryMetric.getBizName(), fieldExpr);
}
}
@@ -120,13 +117,13 @@ public class MetricExpressionConverter implements QueryConverter {
Measure measure = allMeasures.get(field);
if (AggOperatorEnum.COUNT_DISTINCT.getOperator()
.equalsIgnoreCase(measure.getAgg())) {
return AggOption.NATIVE.equals(aggOption) ? measure.getBizName()
return AggOption.NATIVE.equals(aggOption) ? measure.getExpr()
: AggOperatorEnum.COUNT.getOperator() + " ( "
+ AggOperatorEnum.DISTINCT + " "
+ measure.getBizName() + " ) ";
+ AggOperatorEnum.DISTINCT + " " + measure.getExpr()
+ " ) ";
}
String expr = AggOption.NATIVE.equals(aggOption) ? measure.getBizName()
: measure.getAgg() + " ( " + measure.getBizName() + " ) ";
String expr = AggOption.NATIVE.equals(aggOption) ? measure.getExpr()
: measure.getAgg() + " ( " + measure.getExpr() + " ) ";
replace.put(field, expr);
}

View File

@@ -71,7 +71,6 @@ public class SqlQueryConverter implements QueryConverter {
} else if (sqlQueryAggOption.equals(AggOption.NATIVE) && !queryMetrics.isEmpty()) {
ontologyQuery.setAggOption(AggOption.DEFAULT);
}
ontologyQuery.setNativeQuery(!AggOption.isAgg(ontologyQuery.getAggOption()));
queryStatement.setOntologyQuery(ontologyQuery);
log.info("parse sqlQuery [{}] ", sqlQuery);

View File

@@ -35,8 +35,9 @@ public class StructQueryConverter implements QueryConverter {
}
SqlQuery sqlQuery = new SqlQuery();
sqlQuery.setTable(dsTable);
String sql = String.format("select %s from %s %s %s %s",
String sql = String.format("select %s from %s %s %s %s %s",
sqlGenerateUtils.getSelect(structQuery), dsTable,
sqlGenerateUtils.generateWhere(structQuery, null),
sqlGenerateUtils.getGroupBy(structQuery), sqlGenerateUtils.getOrderBy(structQuery),
sqlGenerateUtils.getLimit(structQuery));
Database database = queryStatement.getOntology().getDatabase();

View File

@@ -2,16 +2,14 @@ package com.tencent.supersonic.headless.core.translator.parser.calcite;
import com.tencent.supersonic.common.calcite.Configuration;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.core.pojo.DataModel;
import com.tencent.supersonic.headless.core.pojo.Database;
import com.tencent.supersonic.headless.core.pojo.OntologyQuery;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.DataModelNode;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode;
import com.tencent.supersonic.headless.core.translator.parser.calcite.render.JoinRender;
import com.tencent.supersonic.headless.core.translator.parser.calcite.render.Renderer;
import com.tencent.supersonic.headless.core.translator.parser.calcite.render.SourceRender;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Constants;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParser;
@@ -29,8 +27,6 @@ public class SqlBuilder {
private OntologyQuery ontologyQuery;
private SqlValidatorScope scope;
private SqlNode parserNode;
private boolean isAgg = false;
private AggOption aggOption = AggOption.DEFAULT;
public SqlBuilder(S2CalciteSchema schema) {
this.schema = schema;
@@ -41,7 +37,6 @@ public class SqlBuilder {
if (ontologyQuery.getLimit() == null) {
ontologyQuery.setLimit(0L);
}
this.aggOption = ontologyQuery.getAggOption();
buildParseNode();
Database database = queryStatement.getOntology().getDatabase();
@@ -57,41 +52,25 @@ public class SqlBuilder {
if (dataModels == null || dataModels.isEmpty()) {
throw new Exception("data model not found");
}
isAgg = getAgg(dataModels.get(0));
// build level by level
LinkedList<Renderer> builders = new LinkedList<>();
builders.add(new SourceRender());
builders.add(new JoinRender());
ListIterator<Renderer> it = builders.listIterator();
int i = 0;
Renderer previous = null;
while (it.hasNext()) {
Renderer renderer = it.next();
if (previous != null) {
previous.render(ontologyQuery, dataModels, scope, schema, !isAgg);
previous.render(ontologyQuery, dataModels, scope, schema);
renderer.setTable(previous.builderAs(DataModelNode.getNames(dataModels) + "_" + i));
i++;
}
previous = renderer;
}
builders.getLast().render(ontologyQuery, dataModels, scope, schema, !isAgg);
builders.getLast().render(ontologyQuery, dataModels, scope, schema);
parserNode = builders.getLast().build();
}
/**
 * Decides whether the generated SQL should aggregate, based on the query's AggOption and,
 * when the option is DEFAULT, on the data model's time-aggregation setting.
 *
 * @param dataModel the base data model whose aggTime setting drives the DEFAULT case
 * @return true when aggregation should be applied
 */
private boolean getAgg(DataModel dataModel) {
    // An explicit (non-DEFAULT) option wins outright.
    if (!AggOption.DEFAULT.equals(aggOption)) {
        return AggOption.isAgg(aggOption);
    }
    // default by dataModel time aggregation: a model with a real time granularity
    // aggregates unless the query asked for native (detail) rows.
    if (Objects.nonNull(dataModel.getAggTime()) && !dataModel.getAggTime()
            .equalsIgnoreCase(Constants.DIMENSION_TYPE_TIME_GRANULARITY_NONE)) {
        if (!ontologyQuery.isNativeQuery()) {
            return true;
        }
    }
    // Fall back to the builder's current isAgg field (false unless set elsewhere).
    return isAgg;
}
public String getSql(EngineType engineType) {
return SemanticNode.getSql(parserNode, engineType);
}

View File

@@ -4,13 +4,14 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.headless.core.pojo.DataModel;
import lombok.Data;
import org.apache.calcite.sql.*;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.parser.SqlParserPos;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/** basic query project */
@Data

View File

@@ -9,7 +9,6 @@ import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.core.pojo.DataModel;
import com.tencent.supersonic.headless.core.pojo.JoinRelation;
import com.tencent.supersonic.headless.core.pojo.Ontology;
@@ -101,19 +100,6 @@ public class DataModelNode extends SemanticNode {
dateInfo, dimensions, metrics);
}
/**
 * Builds the SqlNode for a data model, extending it with a LATERAL VIEW EXPLODE wrapper
 * when extra expressions are supplied; with no expressions it degrades to a plain build.
 *
 * @param datasource the data model to render
 * @param exprList field-name to expression mapping to explode; may be empty
 * @param scope Calcite validator scope used for parsing
 * @return the (possibly aliased) SqlNode for the extended data set
 * @throws Exception if parsing any expression fails
 */
public static SqlNode buildExtend(DataModel datasource, Map<String, String> exprList,
        SqlValidatorScope scope) throws Exception {
    // No extra expressions: render the model as-is.
    if (CollectionUtils.isEmpty(exprList)) {
        return build(datasource, scope);
    }
    EngineType engineType = EngineType.fromString(datasource.getType());
    // Wrap the base model node in an explode call over the extended fields.
    SqlNode dataSet = new SqlBasicCall(new LateralViewExplodeNode(exprList),
            Arrays.asList(build(datasource, scope), new SqlNodeList(
                    getExtendField(exprList, scope, engineType), SqlParserPos.ZERO)),
            SqlParserPos.ZERO);
    // Alias the result with the array-suffix so later renderers can reference it.
    return buildAs(datasource.getName() + Constants.DIMENSION_ARRAY_SINGLE_SUFFIX, dataSet);
}
public static List<SqlNode> getExtendField(Map<String, String> exprList,
SqlValidatorScope scope, EngineType engineType) throws Exception {
List<SqlNode> sqlNodeList = new ArrayList<>();
@@ -153,32 +139,6 @@ public class DataModelNode extends SemanticNode {
});
}
public static void mergeQueryFilterDimensionMeasure(Ontology ontology,
OntologyQuery ontologyQuery, Set<String> dimensions, Set<String> measures,
SqlValidatorScope scope) throws Exception {
EngineType engineType = ontology.getDatabase().getType();
if (Objects.nonNull(ontologyQuery.getWhere()) && !ontologyQuery.getWhere().isEmpty()) {
Set<String> filterConditions = new HashSet<>();
FilterNode.getFilterField(parse(ontologyQuery.getWhere(), scope, engineType),
filterConditions);
Set<String> queryMeasures = new HashSet<>(measures);
Set<String> schemaMetricName = ontology.getMetrics().stream()
.map(MetricSchemaResp::getName).collect(Collectors.toSet());
for (String filterCondition : filterConditions) {
if (schemaMetricName.contains(filterCondition)) {
ontology.getMetrics().stream()
.filter(m -> m.getName().equalsIgnoreCase(filterCondition))
.forEach(m -> m.getMetricDefineByMeasureParams().getMeasures()
.forEach(mm -> queryMeasures.add(mm.getName())));
continue;
}
dimensions.add(filterCondition);
}
measures.clear();
measures.addAll(queryMeasures);
}
}
public static List<DataModel> getQueryDataModelsV2(Ontology ontology, OntologyQuery query) {
// first, sort models based on the number of query metrics
Map<String, Integer> modelMetricCount = Maps.newHashMap();
@@ -209,8 +169,8 @@ public class DataModelNode extends SemanticNode {
.collect(Collectors.toList());
Set<String> dataModelNames = Sets.newLinkedHashSet();
dataModelNames.addAll(metricsDataModels);
dataModelNames.addAll(dimDataModels);
dataModelNames.addAll(metricsDataModels);
return dataModelNames.stream().map(bizName -> ontology.getDataModelMap().get(bizName))
.collect(Collectors.toList());
}
@@ -289,45 +249,6 @@ public class DataModelNode extends SemanticNode {
return dataModel;
}
/**
 * Picks the base data model for a query: prefer the model matching the most query
 * measures; if none matches any measure, fall back to the model matching the most
 * query dimensions. Returns null when nothing matches at all.
 *
 * @param ontology the ontology holding all candidate data models
 * @param queryMeasures measure names requested by the query
 * @param queryDimensions dimension names requested by the query
 * @return the best-matching data model, or null if no model matches
 */
private static DataModel findBaseModel(Ontology ontology, Set<String> queryMeasures,
        Set<String> queryDimensions) {
    DataModel dataModel = null;
    // first, try to find the model with the most matching measures
    Map<String, Integer> dataModelMeasuresCount = new HashMap<>();
    for (Map.Entry<String, DataModel> entry : ontology.getDataModelMap().entrySet()) {
        // Intersect each model's measures with the requested ones and count the overlap.
        Set<String> sourceMeasure = entry.getValue().getMeasures().stream()
                .map(Measure::getName).collect(Collectors.toSet());
        sourceMeasure.retainAll(queryMeasures);
        dataModelMeasuresCount.put(entry.getKey(), sourceMeasure.size());
    }
    log.info("dataModelMeasureCount: [{}]", dataModelMeasuresCount);
    // Highest positive overlap wins; ties resolved by stream encounter order.
    Optional<Map.Entry<String, Integer>> base =
            dataModelMeasuresCount.entrySet().stream().filter(e -> e.getValue() > 0)
                    .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst();
    if (base.isPresent()) {
        dataModel = ontology.getDataModelMap().get(base.get().getKey());
    } else {
        // second, try to find the model with the most matching dimensions
        Map<String, Integer> dataModelDimCount = new HashMap<>();
        for (Map.Entry<String, List<DimSchemaResp>> entry : ontology.getDimensionMap()
                .entrySet()) {
            Set<String> modelDimensions = entry.getValue().stream().map(DimSchemaResp::getName)
                    .collect(Collectors.toSet());
            modelDimensions.retainAll(queryDimensions);
            dataModelDimCount.put(entry.getKey(), modelDimensions.size());
        }
        log.info("dataModelDimCount: [{}]", dataModelDimCount);
        base = dataModelDimCount.entrySet().stream().filter(e -> e.getValue() > 0)
                .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst();
        if (base.isPresent()) {
            dataModel = ontology.getDataModelMap().get(base.get().getKey());
        }
    }
    return dataModel;
}
private static boolean checkMatch(DataModel baseDataModel, Set<String> queryMeasures,
Set<String> queryDimension) {
boolean isAllMatch = true;

View File

@@ -9,10 +9,10 @@ import com.tencent.supersonic.headless.core.pojo.JoinRelation;
import com.tencent.supersonic.headless.core.pojo.OntologyQuery;
import com.tencent.supersonic.headless.core.translator.parser.calcite.S2CalciteSchema;
import com.tencent.supersonic.headless.core.translator.parser.calcite.TableView;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.DataModelNode;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.IdentifyNode;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Constants;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Materialization;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.*;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -31,7 +31,7 @@ public class JoinRender extends Renderer {
@Override
public void render(OntologyQuery ontologyQuery, List<DataModel> dataModels,
SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception {
SqlValidatorScope scope, S2CalciteSchema schema) throws Exception {
SqlNode left = null;
TableView leftTable = null;
Map<String, SqlNode> outerSelect = new HashMap<>();
@@ -50,7 +50,8 @@ public class JoinRender extends Renderer {
primary.add(identify.getName());
}
TableView tableView = SourceRender.renderOne(queryMetrics, queryDimensions, dataModel, scope, schema);
TableView tableView =
renderOne(queryMetrics, queryDimensions, dataModel, scope, schema);
log.info("tableView {}", StringUtils.normalizeSpace(tableView.getTable().toString()));
String alias = Constants.JOIN_TABLE_PREFIX + dataModel.getName();
tableView.setAlias(alias);
@@ -60,14 +61,13 @@ public class JoinRender extends Renderer {
outerSelect.put(field, SemanticNode.parse(alias + "." + field, scope, engineType));
}
if (left == null) {
leftTable = tableView;
left = SemanticNode.buildAs(tableView.getAlias(), getTable(tableView));
beforeModels.put(dataModel.getName(), leftTable.getAlias());
continue;
} else {
left = buildJoin(left, leftTable, tableView, beforeModels, dataModel, schema,
scope);
}
left = buildJoin(left, leftTable, tableView, beforeModels, dataModel, schema, scope);
leftTable = tableView;
beforeModels.put(dataModel.getName(), tableView.getAlias());
beforeModels.put(dataModel.getName(), leftTable.getAlias());
}
for (Map.Entry<String, SqlNode> entry : outerSelect.entrySet()) {
@@ -84,28 +84,16 @@ public class JoinRender extends Renderer {
Map<String, String> before, DataModel dataModel, S2CalciteSchema schema,
SqlValidatorScope scope) throws Exception {
EngineType engineType = schema.getOntology().getDatabase().getType();
SqlNode condition = getCondition(leftTable, rightTable, dataModel, schema, scope, engineType);
SqlNode condition =
getCondition(leftTable, rightTable, dataModel, schema, scope, engineType);
SqlLiteral sqlLiteral = SemanticNode.getJoinSqlLiteral("");
JoinRelation matchJoinRelation = getMatchJoinRelation(before, rightTable, schema);
SqlNode joinRelationCondition = null;
SqlNode joinRelationCondition;
if (!CollectionUtils.isEmpty(matchJoinRelation.getJoinCondition())) {
sqlLiteral = SemanticNode.getJoinSqlLiteral(matchJoinRelation.getJoinType());
joinRelationCondition = getCondition(matchJoinRelation, scope, engineType);
condition = joinRelationCondition;
}
if (Materialization.TimePartType.ZIPPER.equals(leftTable.getDataModel().getTimePartType())
|| Materialization.TimePartType.ZIPPER
.equals(rightTable.getDataModel().getTimePartType())) {
SqlNode zipperCondition =
getZipperCondition(leftTable, rightTable, dataModel, schema, scope);
if (Objects.nonNull(joinRelationCondition)) {
condition = new SqlBasicCall(SqlStdOperatorTable.AND,
new ArrayList<>(Arrays.asList(zipperCondition, joinRelationCondition)),
SqlParserPos.ZERO, null);
} else {
condition = zipperCondition;
}
}
return new SqlJoin(SqlParserPos.ZERO, leftNode,
SqlLiteral.createBoolean(false, SqlParserPos.ZERO), sqlLiteral,
@@ -172,7 +160,7 @@ public class JoinRender extends Renderer {
selectLeft.retainAll(selectRight);
SqlNode condition = null;
for (String on : selectLeft) {
if (!SourceRender.isDimension(on, dataModel, schema)) {
if (!isDimension(on, dataModel, schema)) {
continue;
}
if (IdentifyNode.isForeign(on, left.getDataModel().getIdentifiers())) {
@@ -202,104 +190,47 @@ public class JoinRender extends Renderer {
return condition;
}
private static void joinOrder(int cnt, String id, Map<String, Set<String>> next,
Queue<String> orders, Map<String, Boolean> visited) {
visited.put(id, true);
orders.add(id);
if (orders.size() >= cnt) {
return;
public static TableView renderOne(Set<MetricSchemaResp> queryMetrics,
Set<DimSchemaResp> queryDimensions, DataModel dataModel, SqlValidatorScope scope,
S2CalciteSchema schema) {
TableView tableView = new TableView();
EngineType engineType = schema.getOntology().getDatabase().getType();
Set<String> queryFields = tableView.getFields();
queryMetrics.stream().forEach(m -> queryFields.addAll(m.getFields()));
queryDimensions.stream().forEach(m -> queryFields.add(m.getBizName()));
try {
for (String field : queryFields) {
tableView.getSelect().add(SemanticNode.parse(field, scope, engineType));
}
tableView.setTable(DataModelNode.build(dataModel, scope));
} catch (Exception e) {
log.error("Failed to create sqlNode for data model {}", dataModel);
}
for (String nextId : next.get(id)) {
if (!visited.get(nextId)) {
joinOrder(cnt, nextId, next, orders, visited);
if (orders.size() >= cnt) {
return;
}
return tableView;
}
public static boolean isDimension(String name, DataModel dataModel, S2CalciteSchema schema) {
Optional<DimSchemaResp> dimension = dataModel.getDimensions().stream()
.filter(d -> d.getName().equalsIgnoreCase(name)).findFirst();
if (dimension.isPresent()) {
return true;
}
Optional<Identify> identify = dataModel.getIdentifiers().stream()
.filter(i -> i.getName().equalsIgnoreCase(name)).findFirst();
if (identify.isPresent()) {
return true;
}
if (schema.getDimensions().containsKey(dataModel.getName())) {
Optional<DimSchemaResp> dataSourceDim = schema.getDimensions().get(dataModel.getName())
.stream().filter(d -> d.getName().equalsIgnoreCase(name)).findFirst();
if (dataSourceDim.isPresent()) {
return true;
}
}
orders.poll();
visited.put(id, false);
return false;
}
/**
 * Intended to add zipper-table time boundary fields (start/end) to the field list.
 * NOTE(review): the entire implementation is commented out, so this method is currently
 * a no-op — either restore the logic or delete the method.
 */
private void addZipperField(DataModel dataModel, List<String> fields) {
    // if (Materialization.TimePartType.ZIPPER.equals(dataModel.getTimePartType())) {
    // dataModel.getDimensions().stream()
    // .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType()))
    // .forEach(t -> {
    // if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_END)
    // && !fields.contains(t.getName())) {
    // fields.add(t.getName());
    // }
    // if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_START)
    // && !fields.contains(t.getName())) {
    // fields.add(t.getName());
    // }
    // });
    // }
}
/**
 * Intended to build the join condition for zipper (slowly-changing) tables:
 * start_time <= partition_date AND end_time > partition_date.
 * NOTE(review): the implementation is entirely commented out, so this method currently
 * ALWAYS returns null — callers must tolerate a null condition, or this should be
 * restored/removed.
 */
private SqlNode getZipperCondition(TableView left, TableView right, DataModel dataModel,
        S2CalciteSchema schema, SqlValidatorScope scope) throws Exception {
    // if (Materialization.TimePartType.ZIPPER.equals(left.getDataModel().getTimePartType())
    // && Materialization.TimePartType.ZIPPER
    // .equals(right.getDataModel().getTimePartType())) {
    // throw new Exception("not support two zipper table");
    // }
    SqlNode condition = null;
    // Optional<Dimension> leftTime = left.getDataModel().getDimensions().stream()
    // .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType()))
    // .findFirst();
    // Optional<Dimension> rightTime = right.getDataModel().getDimensions().stream()
    // .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType()))
    // .findFirst();
    // if (leftTime.isPresent() && rightTime.isPresent()) {
    //
    // String startTime = "";
    // String endTime = "";
    // String dateTime = "";
    //
    // Optional<Dimension> startTimeOp = (Materialization.TimePartType.ZIPPER
    // .equals(left.getDataModel().getTimePartType()) ? left : right).getDataModel()
    // .getDimensions().stream()
    // .filter(d -> Constants.DIMENSION_TYPE_TIME
    // .equalsIgnoreCase(d.getType()))
    // .filter(d -> d.getName()
    // .startsWith(Constants.MATERIALIZATION_ZIPPER_START))
    // .findFirst();
    // Optional<Dimension> endTimeOp = (Materialization.TimePartType.ZIPPER
    // .equals(left.getDataModel().getTimePartType()) ? left : right).getDataModel()
    // .getDimensions().stream()
    // .filter(d -> Constants.DIMENSION_TYPE_TIME
    // .equalsIgnoreCase(d.getType()))
    // .filter(d -> d.getName()
    // .startsWith(Constants.MATERIALIZATION_ZIPPER_END))
    // .findFirst();
    // if (startTimeOp.isPresent() && endTimeOp.isPresent()) {
    // TableView zipper = Materialization.TimePartType.ZIPPER
    // .equals(left.getDataModel().getTimePartType()) ? left : right;
    // TableView partMetric = Materialization.TimePartType.ZIPPER
    // .equals(left.getDataModel().getTimePartType()) ? right : left;
    // Optional<Dimension> partTime = Materialization.TimePartType.ZIPPER
    // .equals(left.getDataModel().getTimePartType()) ? rightTime : leftTime;
    // startTime = zipper.getAlias() + "." + startTimeOp.get().getName();
    // endTime = zipper.getAlias() + "." + endTimeOp.get().getName();
    // dateTime = partMetric.getAlias() + "." + partTime.get().getName();
    // }
    // EngineType engineType = schema.getOntology().getDatabase().getType();
    // ArrayList<SqlNode> operandList =
    // new ArrayList<>(Arrays.asList(SemanticNode.parse(endTime, scope, engineType),
    // SemanticNode.parse(dateTime, scope, engineType)));
    // condition = new SqlBasicCall(SqlStdOperatorTable.AND,
    // new ArrayList<SqlNode>(Arrays.asList(
    // new SqlBasicCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
    // new ArrayList<SqlNode>(Arrays.asList(
    // SemanticNode.parse(startTime, scope, engineType),
    // SemanticNode.parse(dateTime, scope, engineType))),
    // SqlParserPos.ZERO, null),
    // new SqlBasicCall(SqlStdOperatorTable.GREATER_THAN, operandList,
    // SqlParserPos.ZERO, null))),
    // SqlParserPos.ZERO, null);
    // }
    return condition;
}
}

View File

@@ -1,59 +0,0 @@
package com.tencent.supersonic.headless.core.translator.parser.calcite.render;
import com.tencent.supersonic.common.pojo.ColumnOrder;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.core.pojo.DataModel;
import com.tencent.supersonic.headless.core.pojo.OntologyQuery;
import com.tencent.supersonic.headless.core.translator.parser.calcite.S2CalciteSchema;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.MetricNode;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
/** process the query result items from query request */
/** process the query result items from query request */
public class OutputRender extends Renderer {
    /**
     * Projects the query's dimensions and metrics into the inherited tableView, then
     * applies limit and ordering from the OntologyQuery.
     */
    @Override
    public void render(OntologyQuery ontologyQuery, List<DataModel> dataModels,
            SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception {
        EngineType engineType = schema.getOntology().getDatabase().getType();
        // Project each dimension by its expression.
        for (DimSchemaResp dimension : ontologyQuery.getDimensions()) {
            tableView.getMetric().add(SemanticNode.parse(dimension.getExpr(), scope, engineType));
        }
        for (MetricSchemaResp metric : ontologyQuery.getMetrics()) {
            if (MetricNode.isMetricField(metric.getName(), schema)) {
                // Metrics that are plain fields are already projected; skip them.
                continue;
            }
            tableView.getMetric().add(SemanticNode.parse(metric.getName(), scope, engineType));
        }
        if (ontologyQuery.getLimit() > 0) {
            // NOTE(review): the query's LIMIT value is stored via setOffset — looks like
            // TableView uses "offset" as the row-count node; confirm this is intentional.
            SqlNode offset =
                    SemanticNode.parse(ontologyQuery.getLimit().toString(), scope, engineType);
            tableView.setOffset(offset);
        }
        if (!CollectionUtils.isEmpty(ontologyQuery.getOrder())) {
            List<SqlNode> orderList = new ArrayList<>();
            for (ColumnOrder columnOrder : ontologyQuery.getOrder()) {
                // DESC needs an explicit wrapping call; ASC is the bare column node.
                if (SqlStdOperatorTable.DESC.getName().equalsIgnoreCase(columnOrder.getOrder())) {
                    orderList.add(SqlStdOperatorTable.DESC.createCall(SqlParserPos.ZERO,
                            new SqlNode[] {SemanticNode.parse(columnOrder.getCol(), scope,
                                    engineType)}));
                } else {
                    orderList.add(SemanticNode.parse(columnOrder.getCol(), scope, engineType));
                }
            }
            tableView.setOrder(new SqlNodeList(orderList, SqlParserPos.ZERO));
        }
    }
}

View File

@@ -30,5 +30,5 @@ public abstract class Renderer {
}
public abstract void render(OntologyQuery ontologyQuery, List<DataModel> dataModels,
SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception;
SqlValidatorScope scope, S2CalciteSchema schema) throws Exception;
}

View File

@@ -1,80 +0,0 @@
package com.tencent.supersonic.headless.core.translator.parser.calcite.render;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.core.pojo.DataModel;
import com.tencent.supersonic.headless.core.pojo.OntologyQuery;
import com.tencent.supersonic.headless.core.translator.parser.calcite.S2CalciteSchema;
import com.tencent.supersonic.headless.core.translator.parser.calcite.TableView;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.DataModelNode;
import com.tencent.supersonic.headless.core.translator.parser.calcite.node.SemanticNode;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import java.util.List;
import java.util.Optional;
import java.util.Set;
/** process the table dataSet from the defined data model schema */
@Slf4j
public class SourceRender extends Renderer {
/**
 * Renders a single data model into a TableView: collects the fields needed by the
 * queried metrics and dimensions, parses each into a select SqlNode, and attaches the
 * model's table node.
 *
 * @param queryMetrics metrics whose referenced fields must be selected
 * @param queryDimensions dimensions whose bizNames must be selected
 * @param dataModel the model to render
 * @param scope Calcite validator scope used for parsing
 * @param schema calcite schema providing the database engine type
 * @return the populated TableView (possibly partial if parsing failed)
 */
public static TableView renderOne(Set<MetricSchemaResp> queryMetrics,
        Set<DimSchemaResp> queryDimensions, DataModel dataModel, SqlValidatorScope scope,
        S2CalciteSchema schema) {
    TableView tableView = new TableView();
    EngineType engineType = schema.getOntology().getDatabase().getType();
    Set<String> queryFields = tableView.getFields();
    // Collection.forEach suffices; no stream needed.
    queryMetrics.forEach(m -> queryFields.addAll(m.getFields()));
    queryDimensions.forEach(d -> queryFields.add(d.getBizName()));
    try {
        for (String field : queryFields) {
            tableView.getSelect().add(SemanticNode.parse(field, scope, engineType));
        }
        tableView.setTable(DataModelNode.build(dataModel, scope));
    } catch (Exception e) {
        // Pass the exception as the last argument so the stack trace is logged,
        // instead of being silently dropped.
        log.error("Failed to create sqlNode for data model {}", dataModel, e);
    }
    return tableView;
}
/**
 * Returns true when the given name (case-insensitive) is a dimension of the data model,
 * one of its identifiers, or a schema-level dimension registered for that model.
 */
public static boolean isDimension(String name, DataModel dataModel, S2CalciteSchema schema) {
    // 1) Dimension declared directly on the model.
    boolean isModelDimension = dataModel.getDimensions().stream()
            .anyMatch(d -> d.getName().equalsIgnoreCase(name));
    if (isModelDimension) {
        return true;
    }
    // 2) Identifier (primary/foreign key) of the model.
    boolean isIdentifier = dataModel.getIdentifiers().stream()
            .anyMatch(i -> i.getName().equalsIgnoreCase(name));
    if (isIdentifier) {
        return true;
    }
    // 3) Schema-level dimension registered under this model's name.
    if (schema.getDimensions().containsKey(dataModel.getName())) {
        return schema.getDimensions().get(dataModel.getName()).stream()
                .anyMatch(d -> d.getName().equalsIgnoreCase(name));
    }
    return false;
}
/**
 * Renders the query source: a single model is rendered directly, while multiple models
 * are delegated to JoinRender and its resulting table view is adopted.
 */
public void render(OntologyQuery ontologyQuery, List<DataModel> dataModels,
        SqlValidatorScope scope, S2CalciteSchema schema, boolean nonAgg) throws Exception {
    // Single-model queries need no join handling.
    if (dataModels.size() == 1) {
        DataModel dataModel = dataModels.get(0);
        tableView = renderOne(ontologyQuery.getMetrics(), ontologyQuery.getDimensions(),
                dataModel, scope, schema);
        return;
    }
    // Multiple models: build the joined view and reuse it.
    JoinRender joinRender = new JoinRender();
    joinRender.render(ontologyQuery, dataModels, scope, schema, nonAgg);
    tableView = joinRender.getTableView();
}
}

View File

@@ -141,7 +141,12 @@ public class SqlGenerateUtils {
String whereClauseFromFilter =
sqlFilterUtils.getWhereClause(structQuery.getDimensionFilters());
String whereFromDate = getDateWhereClause(structQuery.getDateInfo(), itemDateResp);
return mergeDateWhereClause(structQuery, whereClauseFromFilter, whereFromDate);
String mergedWhere =
mergeDateWhereClause(structQuery, whereClauseFromFilter, whereFromDate);
if (StringUtils.isNotBlank(mergedWhere)) {
mergedWhere = "where " + mergedWhere;
}
return mergedWhere;
}
private String mergeDateWhereClause(StructQuery structQuery, String whereClauseFromFilter,

View File

@@ -109,7 +109,7 @@ public class ModelServiceImpl implements ModelService {
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
DomainService domainService, UserService userService, DataSetService dataSetService,
DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService,
ApplicationEventPublisher eventPublisher) {
ApplicationEventPublisher eventPublisher) {
this.modelRepository = modelRepository;
this.databaseService = databaseService;
this.dimensionService = dimensionService;

View File

@@ -22,10 +22,10 @@ import com.tencent.supersonic.headless.server.utils.ModelConverter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.context.ApplicationEventPublisher;
import java.util.ArrayList;
import java.util.List;
import org.springframework.context.ApplicationEventPublisher;
import static org.mockito.Mockito.when;

View File

@@ -83,6 +83,7 @@ public class MetricTest extends BaseTest {
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", agent.getId());

View File

@@ -18,6 +18,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository;
import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.util.DataUtils;
import org.apache.commons.collections.CollectionUtils;
@@ -40,6 +41,8 @@ public class BaseTest extends BaseApplication {
protected SchemaService schemaService;
@Autowired
private AgentService agentService;
@Autowired
protected DatabaseService databaseService;
protected Agent agent;
protected SemanticSchema schema;

View File

@@ -1,53 +0,0 @@
package com.tencent.supersonic.headless;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.demo.S2VisitsDemo;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty;
import java.util.Collections;
import java.util.Optional;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/** Verifies that semantic queries translate into physical SQL over the demo dataset. */
public class TranslateTest extends BaseTest {
    private Long dataSetId;

    @BeforeEach
    public void init() {
        agent = getAgentByName(S2VisitsDemo.AGENT_NAME);
        schema = schemaService.getSemanticSchema(agent.getDataSetIds());
        // Fall back to dataset 1 when the agent has no datasets configured.
        Optional<Long> firstId = agent.getDataSetIds().stream().findFirst();
        dataSetId = firstId.orElse(1L);
    }

    @Test
    public void testSqlExplain() throws Exception {
        // Translate a semantic S2SQL query and check the physical SQL
        // references the underlying physical columns.
        String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
        SemanticTranslateResp explain = semanticLayerService
                .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
        assertNotNull(explain);
        String querySql = explain.getQuerySQL();
        assertNotNull(querySql);
        assertTrue(querySql.contains("department"));
        assertTrue(querySql.contains("pv"));
    }

    @Test
    @SetSystemProperty(key = "s2.test", value = "true")
    public void testStructExplain() throws Exception {
        // Translate a struct query grouped by department and check the
        // physical SQL contains the expected physical columns.
        QueryStructReq queryStructReq =
                buildQueryStructReq(Collections.singletonList("department"));
        SemanticTranslateResp explain =
                semanticLayerService.translate(queryStructReq, User.getDefaultUser());
        assertNotNull(explain);
        String querySql = explain.getQuerySQL();
        assertNotNull(querySql);
        assertTrue(querySql.contains("department"));
        assertTrue(querySql.contains("stay_hours"));
    }
}

View File

@@ -0,0 +1,94 @@
package com.tencent.supersonic.headless;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.demo.S2VisitsDemo;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty;
import java.util.Optional;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
 * End-to-end translator tests: each case translates a semantic S2SQL query into
 * physical SQL, checks the physical columns appear, and executes the SQL against
 * the demo database to confirm it is runnable.
 */
public class TranslatorTest extends BaseTest {
    private Long dataSetId;
    private DatabaseResp databaseResp;

    @BeforeEach
    public void init() {
        agent = getAgentByName(S2VisitsDemo.AGENT_NAME);
        schema = schemaService.getSemanticSchema(agent.getDataSetIds());
        // Fall back to dataset 1 when the agent has no datasets configured.
        Optional<Long> id = agent.getDataSetIds().stream().findFirst();
        dataSetId = id.orElse(1L);
        databaseResp = databaseService.getDatabase(1L);
    }

    /**
     * Executes the translated SQL and fails the test if execution reports an error.
     * Uses a JUnit assertion instead of the bare 'assert' keyword: 'assert'
     * statements are silently skipped unless the JVM runs with -ea, which would
     * let SQL execution errors go unnoticed.
     */
    private void executeSql(String sql) {
        SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp);
        assertTrue(StringUtils.isBlank(queryResp.getErrorMsg()));
        System.out.println(
                String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList())));
    }

    @Test
    public void testSql() throws Exception {
        // Simple aggregate over the dataset with a date filter.
        String sql =
                "SELECT SUM(访问次数) AS _总访问次数_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15'";
        SemanticTranslateResp explain = semanticLayerService
                .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
        assertNotNull(explain);
        assertNotNull(explain.getQuerySQL());
        assertTrue(explain.getQuerySQL().contains("count(imp_date)"));
        executeSql(explain.getQuerySQL());
    }

    @Test
    public void testSql_1() throws Exception {
        // Aggregate grouped by a dimension.
        String sql = "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
        SemanticTranslateResp explain = semanticLayerService
                .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
        assertNotNull(explain);
        assertNotNull(explain.getQuerySQL());
        assertTrue(explain.getQuerySQL().contains("department"));
        assertTrue(explain.getQuerySQL().contains("count(imp_date)"));
        executeSql(explain.getQuerySQL());
    }

    @Test
    @SetSystemProperty(key = "s2.test", value = "true")
    public void testSql_2() throws Exception {
        // WITH clause wrapping an aggregate, ordered and limited in the outer query.
        String sql =
                "WITH _department_visits_ AS (SELECT 部门, SUM(访问次数) AS _total_visits_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15' GROUP BY 部门) SELECT 部门 FROM _department_visits_ ORDER BY _total_visits_ DESC LIMIT 2";
        SemanticTranslateResp explain = semanticLayerService
                .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
        assertNotNull(explain);
        assertNotNull(explain.getQuerySQL());
        assertTrue(explain.getQuerySQL().toLowerCase().contains("department"));
        assertTrue(explain.getQuerySQL().toLowerCase().contains("count(imp_date)"));
        executeSql(explain.getQuerySQL());
    }

    @Test
    @SetSystemProperty(key = "s2.test", value = "true")
    public void testSql_3() throws Exception {
        // WITH clause with a dimension filter, no aggregation inside the CTE.
        String sql =
                "WITH recent_data AS (SELECT 用户名, 访问次数 FROM 超音数数据集 WHERE 部门 = 'marketing' AND 数据日期 >= '2024-12-01' AND 数据日期 <= '2024-12-15') SELECT 用户名 FROM recent_data ORDER BY 访问次数 DESC LIMIT 1";
        SemanticTranslateResp explain = semanticLayerService
                .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
        assertNotNull(explain);
        assertNotNull(explain.getQuerySQL());
        assertTrue(explain.getQuerySQL().toLowerCase().contains("department"));
        assertTrue(explain.getQuerySQL().toLowerCase().contains("count(imp_date)"));
        executeSql(explain.getQuerySQL());
    }
}