(improvement)(headless) add simplify optimizer (#618)

This commit is contained in:
jipeli
2024-01-11 17:11:06 +08:00
committed by GitHub
parent 3a5349c916
commit e9c7237794
10 changed files with 278 additions and 148 deletions

View File

@@ -8,17 +8,16 @@ import com.tencent.supersonic.headless.api.request.ParseSqlReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Logical parser that builds a query plan from a ParseSqlReq or MetricReq.
*/
@@ -58,31 +57,16 @@ public class QueryParser {
try {
if (!CollectionUtils.isEmpty(parseSqlReq.getTables())) {
List<String[]> tables = new ArrayList<>();
String sourceId = "";
Boolean isSingleTable = parseSqlReq.getTables().size() == 1;
for (MetricTable metricTable : parseSqlReq.getTables()) {
MetricQueryReq metricReq = new MetricQueryReq();
metricReq.setMetrics(metricTable.getMetrics());
metricReq.setDimensions(metricTable.getDimensions());
metricReq.setWhere(StringUtil.formatSqlQuota(metricTable.getWhere()));
metricReq.setNativeQuery(!AggOption.isAgg(metricTable.getAggOption()));
metricReq.setRootPath(parseSqlReq.getRootPath());
QueryStatement tableSql = new QueryStatement();
tableSql.setIsS2SQL(false);
tableSql.setMetricReq(metricReq);
tableSql.setMinMaxTime(queryStatement.getMinMaxTime());
tableSql.setEnableOptimize(queryStatement.getEnableOptimize());
tableSql.setModelIds(queryStatement.getModelIds());
tableSql.setHeadlessModel(queryStatement.getHeadlessModel());
tableSql = parser(tableSql, metricTable.getAggOption());
if (!tableSql.isOk()) {
queryStatement.setErrMsg(String.format("parser table [%s] error [%s]", metricTable.getAlias(),
tableSql.getErrMsg()));
String metricTableSql = parserSql(metricTable, isSingleTable, parseSqlReq, queryStatement);
if (isSingleTable) {
queryStatement.setSql(metricTableSql);
queryStatement.setParseSqlReq(parseSqlReq);
return queryStatement;
}
tables.add(new String[]{metricTable.getAlias(), tableSql.getSql()});
sourceId = tableSql.getSourceId();
tables.add(new String[]{metricTable.getAlias(), metricTableSql});
}
if (!tables.isEmpty()) {
String sql = "";
if (parseSqlReq.isSupportWith()) {
@@ -97,7 +81,6 @@ public class QueryParser {
}
}
queryStatement.setSql(sql);
queryStatement.setSourceId(sourceId);
queryStatement.setParseSqlReq(parseSqlReq);
return queryStatement;
}
@@ -130,4 +113,32 @@ public class QueryParser {
return queryStatement;
}
/**
 * Builds the SQL for a single metric table by delegating to the metric parser.
 * Also records the parsed source id on the supplied queryStatement as a side effect.
 *
 * @param metricTable         the metric table (metrics, dimensions, where clause, agg option) to parse
 * @param isSingleMetricTable true when the request contains exactly one table; enables view-based simplification
 * @param parseSqlReq         the originating parse request (root path, outer view SQL)
 * @param queryStatement      the enclosing statement; provides time range/model context and receives the source id
 * @return the generated SQL for this metric table
 * @throws Exception if the underlying parser reports an error for this table
 */
// NOTE(review): throws raw Exception; a domain-specific exception type would let callers catch precisely.
private String parserSql(MetricTable metricTable, Boolean isSingleMetricTable, ParseSqlReq parseSqlReq,
QueryStatement queryStatement) throws Exception {
// Translate the metric table into a metric-level query request.
MetricQueryReq metricReq = new MetricQueryReq();
metricReq.setMetrics(metricTable.getMetrics());
metricReq.setDimensions(metricTable.getDimensions());
metricReq.setWhere(StringUtil.formatSqlQuota(metricTable.getWhere()));
// Native query means no aggregation is applied.
metricReq.setNativeQuery(!AggOption.isAgg(metricTable.getAggOption()));
metricReq.setRootPath(parseSqlReq.getRootPath());
// Build a per-table statement that inherits context from the enclosing statement.
QueryStatement tableSql = new QueryStatement();
tableSql.setIsS2SQL(false);
tableSql.setMetricReq(metricReq);
tableSql.setMinMaxTime(queryStatement.getMinMaxTime());
tableSql.setEnableOptimize(queryStatement.getEnableOptimize());
tableSql.setModelIds(queryStatement.getModelIds());
tableSql.setHeadlessModel(queryStatement.getHeadlessModel());
if (isSingleMetricTable) {
// Single-table case: pass the outer view SQL/alias down so the parser can simplify the combined query.
tableSql.setViewSql(parseSqlReq.getSql());
tableSql.setViewAlias(metricTable.getAlias());
}
tableSql = parser(tableSql, metricTable.getAggOption());
if (!tableSql.isOk()) {
throw new Exception(String.format("parser table [%s] error [%s]", metricTable.getAlias(),
tableSql.getErrMsg()));
}
// Side effect: propagate the resolved source id to the caller's statement.
queryStatement.setSourceId(tableSql.getSourceId());
return tableSql.getSql();
}
}

View File

@@ -8,12 +8,15 @@ import com.tencent.supersonic.headless.core.parser.calcite.s2sql.HeadlessModel;
import com.tencent.supersonic.headless.core.parser.calcite.schema.HeadlessSchema;
import com.tencent.supersonic.headless.core.parser.calcite.schema.RuntimeOptions;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
/**
* Calcite-based implementation of the SqlParser interface.
*/
@Component("CalciteSqlParser")
@Slf4j
public class CalciteSqlParser implements SqlParser {
@Override
@@ -30,6 +33,15 @@ public class CalciteSqlParser implements SqlParser {
aggBuilder.explain(queryStatement, isAgg);
queryStatement.setSql(aggBuilder.getSql());
queryStatement.setSourceId(aggBuilder.getSourceId());
if (Objects.nonNull(queryStatement.getViewAlias()) && !queryStatement.getViewAlias().isEmpty()) {
// simplify model sql with query sql
String simplifySql = aggBuilder.simplify(
getSqlByView(aggBuilder.getSql(), queryStatement.getViewSql(), queryStatement.getViewAlias()));
if (Objects.nonNull(simplifySql) && !simplifySql.isEmpty()) {
log.info("simplifySql [{}]", simplifySql);
queryStatement.setSql(simplifySql);
}
}
return queryStatement;
}
@@ -43,4 +55,8 @@ public class CalciteSqlParser implements SqlParser {
.enableOptimize(queryStatement.getEnableOptimize()).build());
return headlessSchema;
}
/**
 * Wraps the generated model SQL as a CTE so the original view query can select from it.
 *
 * @param sql       the generated model SQL that backs the view
 * @param viewSql   the outer query that references the view alias
 * @param viewAlias the alias the outer query uses for the view
 * @return a "with alias as (sql) viewSql" statement
 */
private String getSqlByView(String sql, String viewSql, String viewAlias) {
    return "with " + viewAlias + " as (" + sql + ") " + viewSql;
}
}

View File

@@ -8,33 +8,24 @@ import com.tencent.supersonic.headless.core.parser.calcite.s2sql.Constants;
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.DataSource;
import com.tencent.supersonic.headless.core.parser.calcite.schema.HeadlessSchema;
import com.tencent.supersonic.headless.core.parser.calcite.schema.SchemaBuilder;
import com.tencent.supersonic.headless.core.parser.calcite.schema.SemanticSqlDialect;
import com.tencent.supersonic.headless.core.parser.calcite.sql.Renderer;
import com.tencent.supersonic.headless.core.parser.calcite.sql.TableView;
import com.tencent.supersonic.headless.core.parser.calcite.sql.node.DataSourceNode;
import com.tencent.supersonic.headless.core.parser.calcite.sql.node.SemanticNode;
import com.tencent.supersonic.headless.core.parser.calcite.sql.optimizer.FilterToGroupScanRule;
import com.tencent.supersonic.headless.core.parser.calcite.sql.render.FilterRender;
import com.tencent.supersonic.headless.core.parser.calcite.sql.render.OutputRender;
import com.tencent.supersonic.headless.core.parser.calcite.sql.render.SourceRender;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Objects;
import java.util.Stack;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidatorScope;
/**
* parsing from query dimensions and metrics
@@ -125,33 +116,6 @@ public class AggPlanner implements Planner {
optimize();
}
/**
 * Optimizes the parsed statement in place: converts it to a RelNode, applies the
 * FilterToGroupScanRule via a HepPlanner, and converts the best plan back to SQL.
 * No-op unless optimization is explicitly enabled in the schema's runtime options.
 */
public void optimize() {
    // Skip unless runtime options exist and optimization is explicitly enabled.
    if (Objects.isNull(schema.getRuntimeOptions()) || Objects.isNull(schema.getRuntimeOptions().getEnableOptimize())
            || !schema.getRuntimeOptions().getEnableOptimize()) {
        return;
    }
    HepProgramBuilder hepProgramBuilder = new HepProgramBuilder();
    hepProgramBuilder.addRuleInstance(new FilterToGroupScanRule(FilterToGroupScanRule.DEFAULT, schema));
    RelOptPlanner relOptPlanner = new HepPlanner(hepProgramBuilder.build());
    RelToSqlConverter converter = new RelToSqlConverter(SemanticSqlDialect.DEFAULT);
    SqlValidator sqlValidator = Configuration.getSqlValidator(
            scope.getValidator().getCatalogReader().getRootSchema());
    try {
        log.info("before optimize {}", SemanticNode.getSql(parserNode));
        SqlToRelConverter sqlToRelConverter = Configuration.getSqlToRelConverter(scope, sqlValidator,
                relOptPlanner);
        // Validate then convert SQL -> relational algebra.
        RelNode sqlRel = sqlToRelConverter.convertQuery(
                sqlValidator.validate(parserNode), false, true).rel;
        log.debug("RelNode optimize {}", SemanticNode.getSql(converter.visitRoot(sqlRel).asStatement()));
        relOptPlanner.setRoot(sqlRel);
        RelNode relNode = relOptPlanner.findBestExp();
        // Replace the parsed node with the optimized statement.
        parserNode = converter.visitRoot(relNode).asStatement();
        log.debug("after optimize {}", SemanticNode.getSql(parserNode));
    } catch (Exception e) {
        // Pass the Throwable as the dedicated last argument (no placeholder) so SLF4J
        // logs the full stack trace; "{}" would consume it and print only toString().
        log.error("optimize error", e);
    }
}
@Override
public String getSql() {
return SemanticNode.getSql(parserNode);
@@ -163,7 +127,43 @@ public class AggPlanner implements Planner {
}
@Override
public HeadlessSchema findBest() {
return schema;
public String simplify(String sql) {
return optimize(sql);
}
}
/**
 * Optimizes the parsed statement in place when optimization is enabled in the
 * schema's runtime options; leaves parserNode untouched if optimization fails.
 */
public void optimize() {
// Skip unless runtime options exist and optimization is explicitly enabled.
if (Objects.isNull(schema.getRuntimeOptions()) || Objects.isNull(schema.getRuntimeOptions().getEnableOptimize())
|| !schema.getRuntimeOptions().getEnableOptimize()) {
return;
}
// Re-parse the current SQL text and run it through the optimizer.
SqlNode optimizeNode = optimizeSql(SemanticNode.getSql(parserNode));
// Only replace the parsed node when optimization produced a result.
if (Objects.nonNull(optimizeNode)) {
parserNode = optimizeNode;
}
}
/**
 * Parses the given SQL text and runs it through the semantic optimizer.
 *
 * @param sql the SQL statement to optimize
 * @return the optimized SQL, or an empty string when parsing or optimization fails
 */
public String optimize(String sql) {
    try {
        SqlNode sqlNode = SqlParser.create(sql, Configuration.getParserConfig()).parseStmt();
        if (Objects.nonNull(sqlNode)) {
            return SemanticNode.getSql(SemanticNode.optimize(scope, schema, sqlNode));
        }
    } catch (Exception e) {
        // Pass the Throwable as the dedicated last argument (no placeholder) so SLF4J
        // logs the full stack trace; "{}" would consume it and print only toString().
        log.error("optimize error", e);
    }
    return "";
}
/**
 * Parses the given SQL text and returns the optimized SqlNode.
 *
 * @param sql the SQL statement to optimize
 * @return the optimized SqlNode, or null when parsing or optimization fails
 */
private SqlNode optimizeSql(String sql) {
    try {
        log.info("before optimize:[{}]", sql);
        SqlNode sqlNode = SqlParser.create(sql, Configuration.getParserConfig()).parseStmt();
        if (Objects.nonNull(sqlNode)) {
            return SemanticNode.optimize(scope, schema, sqlNode);
        }
    } catch (Exception e) {
        // Pass the Throwable as the dedicated last argument (no placeholder) so SLF4J
        // logs the full stack trace; "{}" would consume it and print only toString().
        log.error("optimize error", e);
    }
    return null;
}
}

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.core.parser.calcite.planner;
import com.tencent.supersonic.headless.api.enums.AggOption;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.parser.calcite.schema.HeadlessSchema;
/**
* Parses a query and generates SQL together with other execution metadata
@@ -16,5 +15,5 @@ public interface Planner {
public String getSourceId();
public HeadlessSchema findBest();
public String simplify(String sql);
}

View File

@@ -8,7 +8,7 @@ import org.apache.commons.lang3.tuple.Triple;
@Data
@Builder
public class JoinRelation {
private Long id;
private String left;
private String right;
private String joinType;

View File

@@ -54,44 +54,61 @@ public class DataSourceNode extends SemanticNode {
}
private static void addSchema(SqlValidatorScope scope, DataSource datasource, SqlNode table) throws Exception {
Map<String, String> parseInfo = SemanticNode.getDbTable(table);
Map<String, Object> parseInfo = SemanticNode.getDbTable(table);
if (!parseInfo.isEmpty() && parseInfo.containsKey(Constants.SQL_PARSER_TABLE)) {
Set<String> dateInfo = new HashSet<>();
Set<String> dimensions = new HashSet<>();
Set<String> metrics = new HashSet<>();
String db = parseInfo.containsKey(Constants.SQL_PARSER_DB) ? parseInfo.get(Constants.SQL_PARSER_DB) : "";
String tb = parseInfo.get(Constants.SQL_PARSER_TABLE);
for (Dimension d : datasource.getDimensions()) {
List<SqlNode> identifiers = expand(SemanticNode.parse(d.getExpr(), scope), scope);
identifiers.stream().forEach(i -> dimensions.add(i.toString()));
dimensions.add(d.getName());
}
if (parseInfo.containsKey(Constants.SQL_PARSER_FIELD)) {
for (String field : parseInfo.get(Constants.SQL_PARSER_FIELD).split(",")) {
dimensions.add(field);
Map<String, Set<String>> dbTbs = (Map<String, Set<String>>) parseInfo.get(Constants.SQL_PARSER_TABLE);
Map<String, Set<String>> fields = (Map<String, Set<String>>) parseInfo.get(Constants.SQL_PARSER_FIELD);
for (Map.Entry<String, Set<String>> entry : dbTbs.entrySet()) {
for (String dbTb : entry.getValue()) {
String[] dbTable = dbTb.split("\\.");
if (Objects.nonNull(dbTable) && dbTable.length > 0) {
String tb = dbTable.length > 1 ? dbTable[1] : dbTable[0];
String db = dbTable.length > 1 ? dbTable[0] : "";
addSchemaTable(scope, datasource, db, tb,
fields.containsKey(entry.getKey()) ? fields.get(entry.getKey()) : new HashSet<>());
}
}
}
for (Identify i : datasource.getIdentifiers()) {
dimensions.add(i.getName());
}
for (Measure m : datasource.getMeasures()) {
List<SqlNode> identifiers = expand(SemanticNode.parse(m.getExpr(), scope), scope);
identifiers.stream().forEach(i -> {
if (!dimensions.contains(i.toString())) {
metrics.add(i.toString());
}
}
);
if (!dimensions.contains(m.getName())) {
metrics.add(m.getName());
}
}
SchemaBuilder.addSourceView(scope.getValidator().getCatalogReader().getRootSchema(), db,
tb, dateInfo,
dimensions, metrics);
}
}
/**
 * Registers a single db.table as a source view in the Calcite root schema,
 * classifying the datasource's columns into dimensions and metrics.
 *
 * @param scope      validator scope providing access to the catalog's root schema
 * @param datasource the semantic datasource whose dimensions/identifiers/measures are registered
 * @param db         database name ("" when none was parsed)
 * @param tb         table name
 * @param fields     extra field names referenced in the SQL for this table; added as
 *                   dimensions when not already known as a dimension or metric
 * @throws Exception if parsing an expression or registering the view fails
 */
private static void addSchemaTable(SqlValidatorScope scope, DataSource datasource, String db, String tb,
Set<String> fields)
throws Exception {
Set<String> dateInfo = new HashSet<>();
Set<String> dimensions = new HashSet<>();
Set<String> metrics = new HashSet<>();
// Each dimension contributes both its underlying expression identifiers and its name.
for (Dimension d : datasource.getDimensions()) {
List<SqlNode> identifiers = expand(SemanticNode.parse(d.getExpr(), scope), scope);
identifiers.stream().forEach(i -> dimensions.add(i.toString()));
dimensions.add(d.getName());
}
// Identifiers (keys) are treated as dimensions.
for (Identify i : datasource.getIdentifiers()) {
dimensions.add(i.getName());
}
// Measures become metrics unless the same name is already a dimension.
for (Measure m : datasource.getMeasures()) {
List<SqlNode> identifiers = expand(SemanticNode.parse(m.getExpr(), scope), scope);
identifiers.stream().forEach(i -> {
if (!dimensions.contains(i.toString())) {
metrics.add(i.toString());
}
}
);
if (!dimensions.contains(m.getName())) {
metrics.add(m.getName());
}
}
// Any remaining referenced fields default to dimensions so validation can resolve them.
for (String field : fields) {
if (!metrics.contains(field) && !dimensions.contains(field)) {
dimensions.add(field);
log.info("add column {} {}", datasource.getName(), field);
}
}
SchemaBuilder.addSourceView(scope.getValidator().getCatalogReader().getRootSchema(), db,
tb, dateInfo,
dimensions, metrics);
}
public static SqlNode buildExtend(DataSource datasource, Set<String> exprList,
SqlValidatorScope scope)
throws Exception {
@@ -265,7 +282,13 @@ public class DataSourceNode extends SemanticNode {
Set<String> before = new HashSet<>();
before.add(baseDataSource.getName());
if (!CollectionUtils.isEmpty(schema.getJoinRelations())) {
for (JoinRelation joinRelation : schema.getJoinRelations()) {
Set<Long> visitJoinRelations = new HashSet<>();
List<JoinRelation> sortedJoinRelation = new ArrayList<>();
sortJoinRelation(schema.getJoinRelations(), baseDataSource.getName(), visitJoinRelations,
sortedJoinRelation);
schema.getJoinRelations().stream().filter(j -> !visitJoinRelations.contains(j.getId()))
.forEach(j -> sortedJoinRelation.add(j));
for (JoinRelation joinRelation : sortedJoinRelation) {
if (!before.contains(joinRelation.getLeft()) && !before.contains(joinRelation.getRight())) {
continue;
}
@@ -321,6 +344,21 @@ public class DataSourceNode extends SemanticNode {
return linkDataSources;
}
/**
 * Depth-first walk over the join graph starting from the table named {@code next},
 * appending each reachable relation to {@code sortedJoins} in traversal order so
 * that every join chains off a table that is already part of the plan.
 *
 * @param joinRelations all join relations of the schema
 * @param next          the table name to expand from
 * @param visited       ids of relations already placed (also guards against cycles)
 * @param sortedJoins   output list receiving relations in join order
 */
private static void sortJoinRelation(List<JoinRelation> joinRelations, String next, Set<Long> visited,
List<JoinRelation> sortedJoins) {
for (JoinRelation link : joinRelations) {
if (!visited.contains(link.getId())) {
if (link.getLeft().equals(next) || link.getRight().equals(next)) {
visited.add(link.getId());
sortedJoins.add(link);
// Recurse from the opposite end of the matched relation.
sortJoinRelation(joinRelations, link.getLeft().equals(next) ? link.getRight() : link.getLeft(),
visited,
sortedJoins);
}
}
}
}
private static List<DataSource> getLinkDataSources(Set<String> baseIdentifiers,
Set<String> queryDimension,
List<String> measures,
@@ -379,4 +417,4 @@ public class DataSourceNode extends SemanticNode {
}
return Lists.newArrayList();
}
}
}

View File

@@ -3,7 +3,9 @@ package com.tencent.supersonic.headless.core.parser.calcite.sql.node;
import com.tencent.supersonic.headless.core.parser.calcite.Configuration;
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.Constants;
import com.tencent.supersonic.headless.core.parser.calcite.schema.HeadlessSchema;
import com.tencent.supersonic.headless.core.parser.calcite.schema.SemanticSqlDialect;
import com.tencent.supersonic.headless.core.parser.calcite.sql.optimizer.FilterToGroupScanRule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -15,8 +17,13 @@ import java.util.Set;
import java.util.function.UnaryOperator;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlAsOperator;
import org.apache.calcite.sql.SqlBasicCall;
@@ -45,6 +52,7 @@ import org.apache.commons.lang3.StringUtils;
/**
* Base node for semantic model items; provides SQL parsing/visiting helpers.
*/
@Slf4j
public abstract class SemanticNode {
public static Set<SqlKind> AGGREGATION_KIND = new HashSet<>();
@@ -156,7 +164,7 @@ public abstract class SemanticNode {
return sqlNode;
}
private static void sqlVisit(SqlNode sqlNode, Map<String, String> parseInfo) {
private static void sqlVisit(SqlNode sqlNode, Map<String, Object> parseInfo) {
SqlKind kind = sqlNode.getKind();
switch (kind) {
case SELECT:
@@ -164,12 +172,21 @@ public abstract class SemanticNode {
break;
case AS:
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
sqlVisit(sqlBasicCall.getOperandList().get(0), parseInfo);
if (sqlBasicCall.getOperandList().get(0).getKind().equals(SqlKind.IDENTIFIER)) {
addTableName(sqlBasicCall.getOperandList().get(0).toString(),
sqlBasicCall.getOperandList().get(1).toString(), parseInfo);
} else {
sqlVisit(sqlBasicCall.getOperandList().get(0), parseInfo);
}
break;
case JOIN:
SqlJoin sqlJoin = (SqlJoin) sqlNode;
sqlVisit(sqlJoin.getLeft(), parseInfo);
sqlVisit(sqlJoin.getRight(), parseInfo);
SqlBasicCall condition = (SqlBasicCall) sqlJoin.getCondition();
if (Objects.nonNull(condition)) {
condition.getOperandList().stream().forEach(c -> addTagField(c.toString(), parseInfo, ""));
}
break;
case UNION:
((SqlBasicCall) sqlNode).getOperandList().forEach(node -> {
@@ -185,7 +202,7 @@ public abstract class SemanticNode {
}
}
private static void queryVisit(SqlNode select, Map<String, String> parseInfo) {
private static void queryVisit(SqlNode select, Map<String, Object> parseInfo) {
if (select == null) {
return;
}
@@ -197,7 +214,7 @@ public abstract class SemanticNode {
fromVisit(sqlSelect.getFrom(), parseInfo);
}
private static void fieldVisit(SqlNode field, Map<String, String> parseInfo, String func) {
private static void fieldVisit(SqlNode field, Map<String, Object> parseInfo, String func) {
if (field == null) {
return;
}
@@ -237,39 +254,57 @@ public abstract class SemanticNode {
}
}
private static void addTagField(String exp, Map<String, String> parseInfo, String func) {
Set<String> fields = new HashSet<>();
for (String f : exp.split("[^\\w]+")) {
if (Pattern.matches("(?i)[a-z\\d_]+", f)) {
fields.add(f);
private static void addTagField(String exp, Map<String, Object> parseInfo, String func) {
if (!parseInfo.containsKey(Constants.SQL_PARSER_FIELD)) {
parseInfo.put(Constants.SQL_PARSER_FIELD, new HashMap<>());
}
Map<String, Set<String>> fields = (Map<String, Set<String>>) parseInfo.get(Constants.SQL_PARSER_FIELD);
if (Pattern.matches("(?i)[a-z\\d_\\.]+", exp)) {
if (exp.contains(".")) {
String[] res = exp.split("\\.");
if (!fields.containsKey(res[0])) {
fields.put(res[0], new HashSet<>());
}
fields.get(res[0]).add(res[1]);
} else {
if (!fields.containsKey("")) {
fields.put("", new HashSet<>());
}
fields.get("").add(exp);
}
}
if (!fields.isEmpty()) {
parseInfo.put(Constants.SQL_PARSER_FIELD, fields.stream().collect(Collectors.joining(",")));
}
}
private static void fromVisit(SqlNode from, Map<String, String> parseInfo) {
private static void fromVisit(SqlNode from, Map<String, Object> parseInfo) {
SqlKind kind = from.getKind();
switch (kind) {
case IDENTIFIER:
SqlIdentifier sqlIdentifier = (SqlIdentifier) from;
addTableName(sqlIdentifier.toString(), parseInfo);
addTableName(sqlIdentifier.toString(), "", parseInfo);
break;
case AS:
SqlBasicCall sqlBasicCall = (SqlBasicCall) from;
SqlNode selectNode1 = sqlBasicCall.getOperandList().get(0);
if (!SqlKind.UNION.equals(selectNode1.getKind())) {
if (!SqlKind.SELECT.equals(selectNode1.getKind())) {
addTableName(selectNode1.toString(), parseInfo);
SqlNode selectNode0 = sqlBasicCall.getOperandList().get(0);
SqlNode selectNode1 = sqlBasicCall.getOperandList().get(1);
if (!SqlKind.UNION.equals(selectNode0.getKind())) {
if (!SqlKind.SELECT.equals(selectNode0.getKind())) {
addTableName(selectNode0.toString(), selectNode1.toString(), parseInfo);
}
}
sqlVisit(selectNode1, parseInfo);
sqlVisit(selectNode0, parseInfo);
break;
case JOIN:
SqlJoin sqlJoin = (SqlJoin) from;
sqlVisit(sqlJoin.getLeft(), parseInfo);
sqlVisit(sqlJoin.getRight(), parseInfo);
SqlBasicCall condition = (SqlBasicCall) sqlJoin.getCondition();
if (Objects.nonNull(condition)) {
condition.getOperandList().stream().forEach(c -> addTagField(c.toString(), parseInfo, ""));
}
break;
case SELECT:
sqlVisit(from, parseInfo);
@@ -279,27 +314,49 @@ public abstract class SemanticNode {
}
}
private static void addTableName(String exp, Map<String, String> parseInfo) {
private static void addTableName(String exp, String alias, Map<String, Object> parseInfo) {
if (exp.indexOf(" ") > 0) {
return;
}
if (exp.indexOf("_") > 0) {
if (exp.split("_").length > 1) {
String[] dbTb = exp.split("\\.");
if (Objects.nonNull(dbTb) && dbTb.length > 0) {
parseInfo.put(Constants.SQL_PARSER_TABLE, dbTb.length > 1 ? dbTb[1] : dbTb[0]);
parseInfo.put(Constants.SQL_PARSER_DB, dbTb.length > 1 ? dbTb[0] : "");
}
}
if (!parseInfo.containsKey(Constants.SQL_PARSER_TABLE)) {
parseInfo.put(Constants.SQL_PARSER_TABLE, new HashMap<>());
}
Map<String, Set<String>> dbTbs = (Map<String, Set<String>>) parseInfo.get(Constants.SQL_PARSER_TABLE);
if (!dbTbs.containsKey(alias)) {
dbTbs.put(alias, new HashSet<>());
}
dbTbs.get(alias).add(exp);
}
public static Map<String, String> getDbTable(SqlNode sqlNode) {
Map<String, String> parseInfo = new HashMap<>();
/**
 * Walks the given statement and collects the referenced tables and fields.
 *
 * @param sqlNode the parsed statement to inspect
 * @return parse info keyed by the SQL_PARSER_* constants
 */
public static Map<String, Object> getDbTable(SqlNode sqlNode) {
    Map<String, Object> collected = new HashMap<>();
    sqlVisit(sqlNode, collected);
    return collected;
}
/**
 * Optimizes a parsed statement with a HepPlanner applying FilterToGroupScanRule:
 * validates the SQL, converts it to a RelNode, finds the best plan, and converts
 * the result back to a SQL statement.
 *
 * @param scope   validator scope providing the catalog reader
 * @param schema  headless schema consulted by the filter push-down rule
 * @param sqlNode the parsed statement to optimize
 * @return the optimized statement, or null when validation or optimization fails
 */
public static SqlNode optimize(SqlValidatorScope scope, HeadlessSchema schema, SqlNode sqlNode) {
    try {
        HepProgramBuilder hepProgramBuilder = new HepProgramBuilder();
        hepProgramBuilder.addRuleInstance(new FilterToGroupScanRule(FilterToGroupScanRule.DEFAULT, schema));
        RelOptPlanner relOptPlanner = new HepPlanner(hepProgramBuilder.build());
        RelToSqlConverter converter = new RelToSqlConverter(SemanticSqlDialect.DEFAULT);
        SqlValidator sqlValidator = Configuration.getSqlValidator(
                scope.getValidator().getCatalogReader().getRootSchema());
        SqlToRelConverter sqlToRelConverter = Configuration.getSqlToRelConverter(scope, sqlValidator,
                relOptPlanner);
        // Validate then convert SQL -> relational algebra.
        RelNode sqlRel = sqlToRelConverter.convertQuery(
                sqlValidator.validate(sqlNode), false, true).rel;
        log.debug("RelNode optimize {}", SemanticNode.getSql(converter.visitRoot(sqlRel).asStatement()));
        relOptPlanner.setRoot(sqlRel);
        RelNode relNode = relOptPlanner.findBestExp();
        return converter.visitRoot(relNode).asStatement();
    } catch (Exception e) {
        // Pass the Throwable as the dedicated last argument (no placeholder) so SLF4J
        // logs the full stack trace; "{}" would consume it and print only toString().
        log.error("optimize error", e);
    }
    return null;
}
public static RelNode getRelNode(CalciteSchema rootSchema, SqlToRelConverter sqlToRelConverter, String sql)
throws SqlParseException {
SqlValidator sqlValidator = Configuration.getSqlValidator(rootSchema);

View File

@@ -26,6 +26,9 @@ public class QueryStatement {
private List<ImmutablePair<String, String>> timeRanges;
private Boolean enableOptimize = true;
private Triple<String, String, String> minMaxTime;
private String viewSql = "";
private String viewAlias = "";
private HeadlessModel headlessModel;

View File

@@ -227,6 +227,7 @@ public class HeadlessSchemaManager {
conditions.add(Triple.of(rr.getLeftField(), rr.getOperator().getValue(), rr.getRightField()));
}
});
joinRelation.setId(r.getId());
joinRelation.setJoinCondition(conditions);
joinRelations.add(joinRelation);
}
@@ -337,4 +338,4 @@ public class HeadlessSchemaManager {
}
}
}
}

View File

@@ -25,14 +25,6 @@ import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.utils.SqlGenerateUtils;
import com.tencent.supersonic.headless.server.service.Catalog;
import com.tencent.supersonic.headless.server.service.HeadlessQueryEngine;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -41,10 +33,22 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Component
@Slf4j
public class QueryReqConverter {
@Value("${query.sql.limitWrapper:true}")
private Boolean limitWrapper;
@Autowired
private HeadlessQueryEngine headlessQueryEngine;
@Autowired
@@ -57,7 +61,7 @@ public class QueryReqConverter {
private Catalog catalog;
public QueryStatement convert(QueryS2SQLReq queryS2SQLReq,
List<ModelSchemaResp> modelSchemaResps) throws Exception {
List<ModelSchemaResp> modelSchemaResps) throws Exception {
if (CollectionUtils.isEmpty(modelSchemaResps)) {
return new QueryStatement();
@@ -127,7 +131,8 @@ public class QueryReqConverter {
queryStatement.setMinMaxTime(queryStructUtils.getBeginEndTime(queryStructReq));
queryStatement.setModelIds(queryS2SQLReq.getModelIds());
queryStatement = headlessQueryEngine.plan(queryStatement);
queryStatement.setSql(String.format(SqlExecuteReq.LIMIT_WRAPPER, queryStatement.getSql()));
queryStatement.setSql(limitWrapper ? String.format(SqlExecuteReq.LIMIT_WRAPPER, queryStatement.getSql())
: queryStatement.getSql());
return queryStatement;
}