(improvement)(headless) parser add model field (#650)

This commit is contained in:
jipeli
2024-01-18 21:52:36 +08:00
committed by GitHub
parent 90f9da162e
commit b019f4d9bb
16 changed files with 218 additions and 41 deletions

View File

@@ -1,10 +1,17 @@
package com.tencent.supersonic.headless.core.manager;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.FieldParam;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByFieldParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMetricParams;
import com.tencent.supersonic.headless.api.pojo.MetricParam;
import com.tencent.supersonic.headless.api.response.MetricResp;
import com.tencent.supersonic.headless.core.pojo.yaml.FieldParamYamlTpl;
import com.tencent.supersonic.headless.core.pojo.yaml.MeasureYamlTpl;
import com.tencent.supersonic.headless.core.pojo.yaml.MetricParamYamlTpl;
import com.tencent.supersonic.headless.core.pojo.yaml.MetricTypeParamsYamlTpl;
import com.tencent.supersonic.headless.core.pojo.yaml.MetricYamlTpl;
import lombok.extern.slf4j.Slf4j;
@@ -37,12 +44,26 @@ public class MetricYamlManager {
BeanUtils.copyProperties(metric, metricYamlTpl);
metricYamlTpl.setName(metric.getBizName());
metricYamlTpl.setOwners(Lists.newArrayList(metric.getCreatedBy()));
MetricDefineByMeasureParams metricDefineParams = metric.getTypeParams();
MetricTypeParamsYamlTpl metricTypeParamsYamlTpl = new MetricTypeParamsYamlTpl();
metricTypeParamsYamlTpl.setExpr(metricDefineParams.getExpr());
List<MeasureParam> measures = metricDefineParams.getMeasures();
metricTypeParamsYamlTpl.setMeasures(
measures.stream().map(MetricYamlManager::convert).collect(Collectors.toList()));
if (MetricDefineType.MEASURE.equals(metric.getMetricDefineType())) {
MetricDefineByMeasureParams metricDefineParams = metric.getTypeParams();
metricTypeParamsYamlTpl.setExpr(metricDefineParams.getExpr());
List<MeasureParam> measures = metricDefineParams.getMeasures();
metricTypeParamsYamlTpl.setMeasures(
measures.stream().map(MetricYamlManager::convert).collect(Collectors.toList()));
} else if (MetricDefineType.FIELD.equals(metric.getMetricDefineType())) {
MetricDefineByFieldParams metricDefineParams = metric.getMetricDefineByFieldParams();
metricTypeParamsYamlTpl.setExpr(metricDefineParams.getExpr());
List<FieldParam> fields = metricDefineParams.getFields();
metricTypeParamsYamlTpl.setFields(
fields.stream().map(MetricYamlManager::convert).collect(Collectors.toList()));
} else if (MetricDefineType.METRIC.equals(metric.getMetricDefineType())) {
MetricDefineByMetricParams metricDefineByMetricParams = metric.getMetricDefineByMetricParams();
metricTypeParamsYamlTpl.setExpr(metricDefineByMetricParams.getExpr());
List<MetricParam> metrics = metricDefineByMetricParams.getMetrics();
metricTypeParamsYamlTpl.setMetrics(
metrics.stream().map(MetricYamlManager::convert).collect(Collectors.toList()));
}
metricYamlTpl.setTypeParams(metricTypeParamsYamlTpl);
return metricYamlTpl;
}
@@ -55,4 +76,17 @@ public class MetricYamlManager {
return measureYamlTpl;
}
public static FieldParamYamlTpl convert(FieldParam fieldParam) {
FieldParamYamlTpl fieldParamYamlTpl = new FieldParamYamlTpl();
fieldParamYamlTpl.setFieldName(fieldParam.getFieldName());
return fieldParamYamlTpl;
}
public static MetricParamYamlTpl convert(MetricParam metricParam) {
MetricParamYamlTpl metricParamYamlTpl = new MetricParamYamlTpl();
metricParamYamlTpl.setBizName(metricParam.getBizName());
metricParamYamlTpl.setId(metricParam.getId());
return metricParamYamlTpl;
}
}

View File

@@ -53,6 +53,7 @@ public class ModelYamlManager {
} else {
dataModelYamlTpl.setTableQuery(modelDetail.getTableQuery());
}
dataModelYamlTpl.setFields(modelResp.getModelDetail().getFields());
return dataModelYamlTpl;
}

View File

@@ -7,6 +7,8 @@ import lombok.Data;
public class MetricTypeParams {
private List<Measure> measures;
private List<Measure> metrics;
private List<Measure> fields;
private String expr;

View File

@@ -66,13 +66,13 @@ public class SchemaBuilder {
String db = dbSrc.toLowerCase();
DataSourceTable.Builder builder = DataSourceTable.newBuilder(tb);
for (String date : dates) {
builder.addField(date.toLowerCase(), SqlTypeName.VARCHAR);
builder.addField(date, SqlTypeName.VARCHAR);
}
for (String dim : dimensions) {
builder.addField(dim.toLowerCase(), SqlTypeName.VARCHAR);
builder.addField(dim, SqlTypeName.VARCHAR);
}
for (String metric : metrics) {
builder.addField(metric.toLowerCase(), SqlTypeName.BIGINT);
builder.addField(metric, SqlTypeName.BIGINT);
}
DataSourceTable srcTable = builder
.withRowCount(1)

View File

@@ -1,11 +1,15 @@
package com.tencent.supersonic.headless.core.parser.calcite.sql.node;
import java.util.Objects;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.validate.SqlValidatorScope;
public class AggFunctionNode extends SemanticNode {
public static SqlNode build(String agg, String name, SqlValidatorScope scope) throws Exception {
if (Objects.isNull(agg) || agg.isEmpty()) {
return parse(name, scope);
}
if (AggFunction.COUNT_DISTINCT.name().equalsIgnoreCase(agg)) {
return parse(AggFunction.COUNT.name() + " ( " + AggFunction.DISTINCT.name() + " " + name + " ) ", scope);
}

View File

@@ -11,8 +11,8 @@ import com.tencent.supersonic.headless.core.parser.calcite.s2sql.Dimension;
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.Identify;
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.JoinRelation;
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.Measure;
import com.tencent.supersonic.headless.core.parser.calcite.schema.SemanticSchema;
import com.tencent.supersonic.headless.core.parser.calcite.schema.SchemaBuilder;
import com.tencent.supersonic.headless.core.parser.calcite.schema.SemanticSchema;
import com.tencent.supersonic.headless.core.parser.calcite.sql.node.extend.LateralViewExplodeNode;
import java.util.ArrayList;
import java.util.Arrays;
@@ -72,7 +72,10 @@ public class DataSourceNode extends SemanticNode {
String tb = dbTable.length > 1 ? dbTable[1] : dbTable[0];
String db = dbTable.length > 1 ? dbTable[0] : "";
addSchemaTable(scope, datasource, db, tb,
fields.containsKey(entry.getKey()) ? fields.get(entry.getKey()) : new HashSet<>());
fields.containsKey(entry.getKey()) ? fields.get(entry.getKey())
: dbTbs.size() == 1 && fields.size() == 1 && fields.containsKey("")
? fields.get("")
: new HashSet<>());
}
}
}

View File

@@ -47,6 +47,7 @@ import org.apache.calcite.sql.pretty.SqlPrettyWriter;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
/**
@@ -57,6 +58,7 @@ public abstract class SemanticNode {
public static Set<SqlKind> AGGREGATION_KIND = new HashSet<>();
public static Set<String> AGGREGATION_FUNC = new HashSet<>();
public static List<String> groupHints = new ArrayList<>(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"));
static {
AGGREGATION_KIND.add(SqlKind.AVG);
@@ -212,6 +214,59 @@ public abstract class SemanticNode {
fieldVisit(list, parseInfo, "");
});
fromVisit(sqlSelect.getFrom(), parseInfo);
if (sqlSelect.hasWhere()) {
whereVisit((SqlBasicCall) sqlSelect.getWhere(), parseInfo);
}
if (sqlSelect.hasOrderBy()) {
fieldVisit(sqlSelect.getOrderList(), parseInfo, "");
}
SqlNodeList group = sqlSelect.getGroup();
if (group != null) {
group.forEach(groupField -> {
if (groupHints.contains(groupField.toString())) {
int groupIdx = Integer.valueOf(groupField.toString()) - 1;
if (selectList.getList().size() > groupIdx) {
fieldVisit(selectList.get(groupIdx), parseInfo, "");
}
} else {
fieldVisit(groupField, parseInfo, "");
}
});
}
}
private static void whereVisit(SqlBasicCall where, Map<String, Object> parseInfo) {
if (where == null) {
return;
}
if (where.operandCount() == 2 && where.operand(0).getKind().equals(SqlKind.IDENTIFIER)
&& where.operand(1).getKind().equals(SqlKind.LITERAL)) {
fieldVisit(where.operand(0), parseInfo, "");
return;
}
// 子查询
if (where.operandCount() == 2
&& (where.operand(0).getKind().equals(SqlKind.IDENTIFIER)
&& (where.operand(1).getKind().equals(SqlKind.SELECT)
|| where.operand(1).getKind().equals(SqlKind.ORDER_BY)))
) {
fieldVisit(where.operand(0), parseInfo, "");
sqlVisit((SqlNode) (where.operand(1)), parseInfo);
return;
}
if (CollectionUtils.isNotEmpty(where.getOperandList()) && where.operand(0).getKind()
.equals(SqlKind.IDENTIFIER)) {
fieldVisit(where.operand(0), parseInfo, "");
}
if (where.operandCount() >= 2 && where.operand(1).getKind().equals(SqlKind.IDENTIFIER)) {
fieldVisit(where.operand(1), parseInfo, "");
}
if (CollectionUtils.isNotEmpty(where.getOperandList()) && where.operand(0) instanceof SqlBasicCall) {
whereVisit(where.operand(0), parseInfo);
}
if (where.operandCount() >= 2 && where.operand(1) instanceof SqlBasicCall) {
whereVisit(where.operand(1), parseInfo);
}
}
private static void fieldVisit(SqlNode field, Map<String, Object> parseInfo, String func) {

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.core.pojo.yaml;
import com.tencent.supersonic.headless.api.enums.ModelSourceType;
import com.tencent.supersonic.headless.api.pojo.Field;
import lombok.Data;
import java.util.List;
@@ -27,6 +28,8 @@ public class DataModelYamlTpl {
private List<MeasureYamlTpl> measures;
private List<Field> fields;
private ModelSourceType modelSourceTypeEnum;

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.headless.core.pojo.yaml;
import lombok.Data;
@Data
public class FieldParamYamlTpl {
private String fieldName;
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.headless.core.pojo.yaml;
import lombok.Data;
@Data
public class MetricParamYamlTpl {
private Long id;
private String bizName;
}

View File

@@ -9,7 +9,10 @@ public class MetricTypeParamsYamlTpl {
private List<MeasureYamlTpl> measures;
private List<MetricParamYamlTpl> metrics;
private List<FieldParamYamlTpl> fields;
private String expr;
}

View File

@@ -19,7 +19,6 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.headless.api.enums.EngineType;
import com.tencent.supersonic.headless.api.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.enums.MetricType;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.api.response.DimensionResp;
@@ -283,19 +282,14 @@ public class SqlGenerateUtils {
Optional<MetricResp> metricItem = metricResps.stream()
.filter(m -> m.getBizName().equalsIgnoreCase(field)).findFirst();
if (metricItem.isPresent()) {
if (MetricType.isDerived(metricItem.get().getMetricDefineType(),
metricItem.get().getTypeParams())) {
if (visitedMetric.contains(field)) {
break;
}
replace.put(field,
generateDerivedMetric(metricResps, allFields, allMeasures, dimensionResps,
getExpr(metricItem.get()), metricItem.get().getMetricDefineType(),
visitedMetric, measures, dimensions));
visitedMetric.add(field);
} else {
replace.put(field, getExpr(metricItem.get()));
if (visitedMetric.contains(field)) {
break;
}
replace.put(field,
generateDerivedMetric(metricResps, allFields, allMeasures, dimensionResps,
getExpr(metricItem.get()), metricItem.get().getMetricDefineType(),
visitedMetric, measures, dimensions));
visitedMetric.add(field);
}
break;
case MEASURE:
@@ -321,7 +315,9 @@ public class SqlGenerateUtils {
}
}
if (!CollectionUtils.isEmpty(replace)) {
return SqlParserReplaceHelper.replaceExpression(expression, replace);
String expr = SqlParserReplaceHelper.replaceExpression(expression, replace);
log.info("derived measure {}->{}", expression, expr);
return expr;
}
}
return expression;