[improvement][headless]Clean code logic of headless translator.

This commit is contained in:
jerryjzhang
2024-11-27 11:29:29 +08:00
parent 7bf1ba09c5
commit dad065d0ba
10 changed files with 184 additions and 276 deletions

View File

@@ -50,7 +50,7 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
} }
} catch (Exception e) { } catch (Exception e) {
queryStatement.setErrMsg(e.getMessage()); queryStatement.setErrMsg(e.getMessage());
log.error("Failed to translate semantic query [{}]", e.getMessage(), e); log.error("Failed to translate query [{}]", e.getMessage(), e);
} }
} }

View File

@@ -25,8 +25,4 @@ public class Ontology {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
public Map<Long, DataModel> getModelMap() {
return dataModelMap.values().stream()
.collect(Collectors.toMap(DataModel::getId, dataSource -> dataSource));
}
} }

View File

@@ -52,7 +52,7 @@ public class SqlBuilder {
// find relevant data models // find relevant data models
scope = SchemaBuilder.getScope(schema); scope = SchemaBuilder.getScope(schema);
List<DataModel> dataModels = List<DataModel> dataModels =
DataModelNode.getRelatedDataModels(scope, schema, ontologyQueryParam); DataModelNode.getQueryDataModels(scope, schema, ontologyQueryParam);
if (dataModels == null || dataModels.isEmpty()) { if (dataModels == null || dataModels.isEmpty()) {
throw new Exception("data model not found"); throw new Exception("data model not found");
} }
@@ -98,20 +98,6 @@ public class SqlBuilder {
return SemanticNode.getSql(parserNode, engineType); return SemanticNode.getSql(parserNode, engineType);
} }
private String rewrite(String sql, EngineType engineType) {
try {
SqlNode sqlNode =
SqlParser.create(sql, Configuration.getParserConfig(engineType)).parseStmt();
if (Objects.nonNull(sqlNode)) {
return SemanticNode.getSql(
SemanticNode.optimize(scope, schema, sqlNode, engineType), engineType);
}
} catch (Exception e) {
log.error("optimize error {}", e.toString());
}
return "";
}
private void optimizeParseNode(EngineType engineType) { private void optimizeParseNode(EngineType engineType) {
if (Objects.isNull(schema.getRuntimeOptions()) if (Objects.isNull(schema.getRuntimeOptions())
|| Objects.isNull(schema.getRuntimeOptions().getEnableOptimize()) || Objects.isNull(schema.getRuntimeOptions().getEnableOptimize())

View File

@@ -4,36 +4,17 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.common.calcite.Configuration; import com.tencent.supersonic.common.calcite.Configuration;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.EngineType; import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Constants; import com.tencent.supersonic.headless.core.translator.calcite.s2sql.*;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.DataModel;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Dimension;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Identify;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.JoinRelation;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Measure;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.OntologyQueryParam;
import com.tencent.supersonic.headless.core.translator.calcite.sql.S2CalciteSchema; import com.tencent.supersonic.headless.core.translator.calcite.sql.S2CalciteSchema;
import com.tencent.supersonic.headless.core.translator.calcite.sql.SchemaBuilder; import com.tencent.supersonic.headless.core.translator.calcite.sql.SchemaBuilder;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.*;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlUserDefinedTypeNameSpec;
import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -53,7 +34,7 @@ public class DataModelNode extends SemanticNode {
} }
} }
if (sqlTable.isEmpty()) { if (sqlTable.isEmpty()) {
throw new Exception("DatasourceNode build error [tableSqlNode not found]"); throw new Exception("DataModelNode build error [tableSqlNode not found]");
} }
SqlNode source = getTable(sqlTable, scope, EngineType.fromString(dataModel.getType())); SqlNode source = getTable(sqlTable, scope, EngineType.fromString(dataModel.getType()));
addSchema(scope, dataModel, sqlTable); addSchema(scope, dataModel, sqlTable);
@@ -149,166 +130,171 @@ public class DataModelNode extends SemanticNode {
return dataModelList.stream().map(d -> d.getName()).collect(Collectors.joining("_")); return dataModelList.stream().map(d -> d.getName()).collect(Collectors.joining("_"));
} }
public static void getQueryDimensionMeasure(S2CalciteSchema schema, public static void getQueryDimensionMeasure(Ontology ontology, OntologyQueryParam queryParam,
OntologyQueryParam queryParam, Set<String> queryDimensions, Set<String> queryMeasures) { Set<String> queryDimensions, Set<String> queryMeasures) {
queryDimensions.addAll(queryParam.getDimensions().stream() queryDimensions.addAll(queryParam.getDimensions().stream()
.map(d -> d.contains(Constants.DIMENSION_IDENTIFY) .map(d -> d.contains(Constants.DIMENSION_IDENTIFY)
? d.split(Constants.DIMENSION_IDENTIFY)[1] ? d.split(Constants.DIMENSION_IDENTIFY)[1]
: d) : d)
.collect(Collectors.toSet())); .collect(Collectors.toSet()));
Set<String> schemaMetricName = Set<String> schemaMetricName =
schema.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet()); ontology.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet());
schema.getMetrics().stream().filter(m -> queryParam.getMetrics().contains(m.getName())) ontology.getMetrics().stream().filter(m -> queryParam.getMetrics().contains(m.getName()))
.forEach(m -> m.getMetricTypeParams().getMeasures().stream() .forEach(m -> m.getMetricTypeParams().getMeasures().stream()
.forEach(mm -> queryMeasures.add(mm.getName()))); .forEach(mm -> queryMeasures.add(mm.getName())));
queryParam.getMetrics().stream().filter(m -> !schemaMetricName.contains(m)) queryParam.getMetrics().stream().filter(m -> !schemaMetricName.contains(m))
.forEach(m -> queryMeasures.add(m)); .forEach(m -> queryMeasures.add(m));
} }
public static void mergeQueryFilterDimensionMeasure(S2CalciteSchema schema, public static void mergeQueryFilterDimensionMeasure(Ontology ontology,
OntologyQueryParam metricCommand, Set<String> queryDimension, Set<String> measures, OntologyQueryParam queryParam, Set<String> dimensions, Set<String> measures,
SqlValidatorScope scope) throws Exception { SqlValidatorScope scope) throws Exception {
EngineType engineType = schema.getOntology().getDatabase().getType(); EngineType engineType = ontology.getDatabase().getType();
if (Objects.nonNull(metricCommand.getWhere()) && !metricCommand.getWhere().isEmpty()) { if (Objects.nonNull(queryParam.getWhere()) && !queryParam.getWhere().isEmpty()) {
Set<String> filterConditions = new HashSet<>(); Set<String> filterConditions = new HashSet<>();
FilterNode.getFilterField(parse(metricCommand.getWhere(), scope, engineType), FilterNode.getFilterField(parse(queryParam.getWhere(), scope, engineType),
filterConditions); filterConditions);
Set<String> queryMeasures = new HashSet<>(measures); Set<String> queryMeasures = new HashSet<>(measures);
Set<String> schemaMetricName = Set<String> schemaMetricName = ontology.getMetrics().stream().map(m -> m.getName())
schema.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet()); .collect(Collectors.toSet());
for (String filterCondition : filterConditions) { for (String filterCondition : filterConditions) {
if (schemaMetricName.contains(filterCondition)) { if (schemaMetricName.contains(filterCondition)) {
schema.getMetrics().stream() ontology.getMetrics().stream()
.filter(m -> m.getName().equalsIgnoreCase(filterCondition)) .filter(m -> m.getName().equalsIgnoreCase(filterCondition))
.forEach(m -> m.getMetricTypeParams().getMeasures().stream() .forEach(m -> m.getMetricTypeParams().getMeasures().stream()
.forEach(mm -> queryMeasures.add(mm.getName()))); .forEach(mm -> queryMeasures.add(mm.getName())));
continue; continue;
} }
queryDimension.add(filterCondition); dimensions.add(filterCondition);
} }
measures.clear(); measures.clear();
measures.addAll(queryMeasures); measures.addAll(queryMeasures);
} }
} }
public static List<DataModel> getRelatedDataModels(SqlValidatorScope scope, public static List<DataModel> getQueryDataModels(SqlValidatorScope scope,
S2CalciteSchema schema, OntologyQueryParam queryParam) throws Exception { S2CalciteSchema schema, OntologyQueryParam queryParam) throws Exception {
List<DataModel> dataModels = new ArrayList<>(); Ontology ontology = schema.getOntology();
// get query measures and dimensions
// check by metric
Set<String> queryMeasures = new HashSet<>(); Set<String> queryMeasures = new HashSet<>();
Set<String> queryDimensions = new HashSet<>(); Set<String> queryDimensions = new HashSet<>();
getQueryDimensionMeasure(schema, queryParam, queryDimensions, queryMeasures); getQueryDimensionMeasure(ontology, queryParam, queryDimensions, queryMeasures);
DataModel baseDataModel = null; mergeQueryFilterDimensionMeasure(ontology, queryParam, queryDimensions, queryMeasures,
// one , match measure count scope);
Map<String, Integer> dataSourceMeasures = new HashMap<>();
for (Map.Entry<String, DataModel> entry : schema.getDataModels().entrySet()) { // first, find the base model
Set<String> sourceMeasure = entry.getValue().getMeasures().stream() DataModel baseDataModel = findBaseModel(ontology, queryMeasures, queryDimensions);
.map(mm -> mm.getName()).collect(Collectors.toSet()); if (Objects.isNull(baseDataModel)) {
sourceMeasure.retainAll(queryMeasures); throw new RuntimeException(
dataSourceMeasures.put(entry.getKey(), sourceMeasure.size()); String.format("could not find matching dataModel, dimensions:%s, measures:%s",
}
log.info("metrics: [{}]", dataSourceMeasures);
Optional<Map.Entry<String, Integer>> base = dataSourceMeasures.entrySet().stream()
.sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst();
if (base.isPresent()) {
baseDataModel = schema.getDataModels().get(base.get().getKey());
dataModels.add(baseDataModel);
}
// second , check match all dimension and metric
if (baseDataModel != null) {
Set<String> filterMeasure = new HashSet<>();
Set<String> sourceMeasure = baseDataModel.getMeasures().stream().map(mm -> mm.getName())
.collect(Collectors.toSet());
Set<String> dimension = baseDataModel.getDimensions().stream().map(dd -> dd.getName())
.collect(Collectors.toSet());
baseDataModel.getIdentifiers().stream().forEach(i -> dimension.add(i.getName()));
if (schema.getDimensions().containsKey(baseDataModel.getName())) {
schema.getDimensions().get(baseDataModel.getName()).stream()
.forEach(d -> dimension.add(d.getName()));
}
filterMeasure.addAll(sourceMeasure);
filterMeasure.addAll(dimension);
EngineType engineType = schema.getOntology().getDatabase().getType();
mergeQueryFilterDimensionMeasure(schema, queryParam, queryDimensions, queryMeasures,
scope);
boolean isAllMatch = checkMatch(sourceMeasure, queryDimensions, queryMeasures,
dimension, queryParam, scope, engineType);
if (isAllMatch) {
log.debug("baseDataModel match all ");
return dataModels;
}
// find all dataSource has the same identifiers
List<DataModel> linkDataModels = getLinkDataSourcesByJoinRelation(queryDimensions,
queryMeasures, baseDataModel, schema);
if (CollectionUtils.isEmpty(linkDataModels)) {
log.debug("baseDataModel get by identifiers ");
Set<String> baseIdentifiers = baseDataModel.getIdentifiers().stream()
.map(i -> i.getName()).collect(Collectors.toSet());
if (baseIdentifiers.isEmpty()) {
throw new Exception(
"datasource error : " + baseDataModel.getName() + " miss identifier");
}
linkDataModels = getLinkDataSources(baseIdentifiers, queryDimensions, queryMeasures,
baseDataModel, schema);
if (linkDataModels.isEmpty()) {
throw new Exception(String.format(
"not find the match datasource : dimension[%s],measure[%s]",
queryDimensions, queryMeasures)); queryDimensions, queryMeasures));
} }
} // if the base model matches all queried measures and dimensions, just return
log.debug("linkDataModels {}", linkDataModels); if (checkMatch(baseDataModel, queryMeasures, queryDimensions)) {
return linkDataModels; log.debug("baseDataModel match all measures and dimensions");
return Collections.singletonList(baseDataModel);
} }
return dataModels; // second, traverse the ontology to find other related dataModels
List<DataModel> relatedDataModels = findRelatedModelsByRelation(ontology, baseDataModel,
queryDimensions, queryMeasures);
if (CollectionUtils.isEmpty(relatedDataModels)) {
relatedDataModels = findRelatedModelsByIdentifier(ontology, baseDataModel,
queryDimensions, queryMeasures);
}
if (CollectionUtils.isEmpty(relatedDataModels)) {
relatedDataModels = Collections.singletonList(baseDataModel);
}
log.debug("relatedDataModels {}", relatedDataModels);
return relatedDataModels;
} }
private static boolean checkMatch(Set<String> sourceMeasure, Set<String> queryDimension, private static DataModel findBaseModel(Ontology ontology, Set<String> queryMeasures,
Set<String> measures, Set<String> dimension, OntologyQueryParam metricCommand, Set<String> queryDimensions) {
SqlValidatorScope scope, EngineType engineType) throws Exception { DataModel dataModel = null;
boolean isAllMatch = true; // first, try to find the model with the most matching measures
sourceMeasure.retainAll(measures); Map<String, Integer> dataModelMeasuresCount = new HashMap<>();
if (sourceMeasure.size() < measures.size()) { for (Map.Entry<String, DataModel> entry : ontology.getDataModelMap().entrySet()) {
log.info("baseDataSource measures not match all measure"); Set<String> sourceMeasure = entry.getValue().getMeasures().stream()
// check dimension again .map(Measure::getName).collect(Collectors.toSet());
Set<String> dimensionMeasures = new HashSet<>(); sourceMeasure.retainAll(queryMeasures);
dimensionMeasures.addAll(dimension); dataModelMeasuresCount.put(entry.getKey(), sourceMeasure.size());
dimensionMeasures.retainAll(measures); }
if (sourceMeasure.size() + dimensionMeasures.size() < measures.size()) { log.info("dataModelMeasureCount: [{}]", dataModelMeasuresCount);
log.info("baseDataSource not match all measure"); Optional<Map.Entry<String, Integer>> base =
isAllMatch = false; dataModelMeasuresCount.entrySet().stream().filter(e -> e.getValue() > 0)
.sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst();
if (base.isPresent()) {
dataModel = ontology.getDataModelMap().get(base.get().getKey());
} else {
// second, try to find the model with the most matching dimensions
Map<String, Integer> dataModelDimCount = new HashMap<>();
for (Map.Entry<String, List<Dimension>> entry : ontology.getDimensionMap().entrySet()) {
Set<String> modelDimensions = entry.getValue().stream().map(Dimension::getName)
.collect(Collectors.toSet());
modelDimensions.retainAll(queryDimensions);
dataModelDimCount.put(entry.getKey(), modelDimensions.size());
}
log.info("dataModelDimCount: [{}]", dataModelDimCount);
base = dataModelDimCount.entrySet().stream().filter(e -> e.getValue() > 0)
.sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst();
if (base.isPresent()) {
dataModel = ontology.getDataModelMap().get(base.get().getKey());
} }
} }
measures.removeAll(sourceMeasure);
dimension.retainAll(queryDimension); return dataModel;
if (dimension.size() < queryDimension.size()) { }
log.debug("baseDataSource not match all dimension");
private static boolean checkMatch(DataModel baseDataModel, Set<String> queryMeasures,
Set<String> queryDimension) {
boolean isAllMatch = true;
Set<String> baseMeasures = baseDataModel.getMeasures().stream().map(Measure::getName)
.collect(Collectors.toSet());
Set<String> baseDimensions = baseDataModel.getDimensions().stream().map(Dimension::getName)
.collect(Collectors.toSet());
baseDataModel.getIdentifiers().stream().forEach(i -> baseDimensions.add(i.getName()));
baseMeasures.retainAll(queryMeasures);
if (baseMeasures.size() < queryMeasures.size()) {
// check dimension again
Set<String> dimensionMeasures = new HashSet<>();
dimensionMeasures.addAll(baseDimensions);
dimensionMeasures.retainAll(queryMeasures);
if (baseMeasures.size() + dimensionMeasures.size() < queryMeasures.size()) {
log.info("baseDataModel not match all measures");
isAllMatch = false;
}
queryMeasures.removeAll(dimensionMeasures);
}
queryMeasures.removeAll(baseMeasures);
baseDimensions.retainAll(queryDimension);
if (baseDimensions.size() < queryDimension.size()) {
log.debug("baseDataModel not match all dimensions");
isAllMatch = false; isAllMatch = false;
} }
queryDimension.removeAll(dimension); queryDimension.removeAll(baseDimensions);
if (metricCommand.getWhere() != null && !metricCommand.getWhere().isEmpty()) {
Set<String> whereFields = new HashSet<>();
SqlNode sqlNode = parse(metricCommand.getWhere(), scope, engineType);
FilterNode.getFilterField(sqlNode, whereFields);
}
return isAllMatch; return isAllMatch;
} }
private static List<DataModel> getLinkDataSourcesByJoinRelation(Set<String> queryDimension, private static List<DataModel> findRelatedModelsByRelation(Ontology ontology,
Set<String> measures, DataModel baseDataModel, S2CalciteSchema schema) { DataModel baseDataModel, Set<String> queryDimensions, Set<String> queryMeasures) {
Set<String> linkDataSourceName = new HashSet<>(); Set<String> joinDataModelNames = new HashSet<>();
List<DataModel> linkDataModels = new ArrayList<>(); List<DataModel> joinDataModels = new ArrayList<>();
Set<String> before = new HashSet<>(); Set<String> before = new HashSet<>();
before.add(baseDataModel.getName()); before.add(baseDataModel.getName());
if (!CollectionUtils.isEmpty(schema.getJoinRelations())) {
if (!CollectionUtils.isEmpty(ontology.getJoinRelations())) {
Set<Long> visitJoinRelations = new HashSet<>(); Set<Long> visitJoinRelations = new HashSet<>();
List<JoinRelation> sortedJoinRelation = new ArrayList<>(); List<JoinRelation> sortedJoinRelation = new ArrayList<>();
sortJoinRelation(schema.getJoinRelations(), baseDataModel.getName(), visitJoinRelations, sortJoinRelation(ontology.getJoinRelations(), baseDataModel.getName(),
sortedJoinRelation); visitJoinRelations, sortedJoinRelation);
schema.getJoinRelations().stream().filter(j -> !visitJoinRelations.contains(j.getId())) ontology.getJoinRelations().stream()
.filter(j -> !visitJoinRelations.contains(j.getId()))
.forEach(j -> sortedJoinRelation.add(j)); .forEach(j -> sortedJoinRelation.add(j));
for (JoinRelation joinRelation : sortedJoinRelation) { for (JoinRelation joinRelation : sortedJoinRelation) {
if (!before.contains(joinRelation.getLeft()) if (!before.contains(joinRelation.getLeft())
@@ -317,53 +303,54 @@ public class DataModelNode extends SemanticNode {
} }
boolean isMatch = false; boolean isMatch = false;
boolean isRight = before.contains(joinRelation.getLeft()); boolean isRight = before.contains(joinRelation.getLeft());
DataModel other = isRight ? schema.getDataModels().get(joinRelation.getRight()) DataModel other = isRight ? ontology.getDataModelMap().get(joinRelation.getRight())
: schema.getDataModels().get(joinRelation.getLeft()); : ontology.getDataModelMap().get(joinRelation.getLeft());
if (!queryDimension.isEmpty()) { if (!queryDimensions.isEmpty()) {
Set<String> linkDimension = other.getDimensions().stream() Set<String> linkDimension = other.getDimensions().stream()
.map(dd -> dd.getName()).collect(Collectors.toSet()); .map(dd -> dd.getName()).collect(Collectors.toSet());
other.getIdentifiers().stream().forEach(i -> linkDimension.add(i.getName())); other.getIdentifiers().stream().forEach(i -> linkDimension.add(i.getName()));
linkDimension.retainAll(queryDimension); linkDimension.retainAll(queryDimensions);
if (!linkDimension.isEmpty()) { if (!linkDimension.isEmpty()) {
isMatch = true; isMatch = true;
} }
} }
Set<String> linkMeasure = other.getMeasures().stream().map(mm -> mm.getName()) Set<String> linkMeasure = other.getMeasures().stream().map(Measure::getName)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
linkMeasure.retainAll(measures); linkMeasure.retainAll(queryMeasures);
if (!linkMeasure.isEmpty()) { if (!linkMeasure.isEmpty()) {
isMatch = true; isMatch = true;
} }
if (!isMatch && schema.getDimensions().containsKey(other.getName())) { if (!isMatch && ontology.getDimensionMap().containsKey(other.getName())) {
Set<String> linkDimension = schema.getDimensions().get(other.getName()).stream() Set<String> linkDimension = ontology.getDimensionMap().get(other.getName())
.map(dd -> dd.getName()).collect(Collectors.toSet()); .stream().map(dd -> dd.getName()).collect(Collectors.toSet());
linkDimension.retainAll(queryDimension); linkDimension.retainAll(queryDimensions);
if (!linkDimension.isEmpty()) { if (!linkDimension.isEmpty()) {
isMatch = true; isMatch = true;
} }
} }
if (isMatch) { if (isMatch) {
linkDataSourceName.add(other.getName()); joinDataModelNames.add(other.getName());
before.add(other.getName()); before.add(other.getName());
} }
} }
} }
if (!CollectionUtils.isEmpty(linkDataSourceName)) { if (!CollectionUtils.isEmpty(joinDataModelNames)) {
Map<String, Long> orders = new HashMap<>(); Map<String, Long> orders = new HashMap<>();
linkDataSourceName.add(baseDataModel.getName()); joinDataModelNames.add(baseDataModel.getName());
orders.put(baseDataModel.getName(), 0L); orders.put(baseDataModel.getName(), 0L);
for (JoinRelation joinRelation : schema.getJoinRelations()) { for (JoinRelation joinRelation : ontology.getJoinRelations()) {
if (linkDataSourceName.contains(joinRelation.getLeft()) if (joinDataModelNames.contains(joinRelation.getLeft())
&& linkDataSourceName.contains(joinRelation.getRight())) { && joinDataModelNames.contains(joinRelation.getRight())) {
orders.put(joinRelation.getLeft(), 0L); orders.put(joinRelation.getLeft(), 0L);
orders.put(joinRelation.getRight(), 1L); orders.put(joinRelation.getRight(), 1L);
} }
} }
orders.entrySet().stream().sorted(Map.Entry.comparingByValue()).forEach(d -> { orders.entrySet().stream().sorted(Map.Entry.comparingByValue()).forEach(d -> {
linkDataModels.add(schema.getDataModels().get(d.getKey())); joinDataModels.add(ontology.getDataModelMap().get(d.getKey()));
}); });
} }
return linkDataModels;
return joinDataModels;
} }
private static void sortJoinRelation(List<JoinRelation> joinRelations, String next, private static void sortJoinRelation(List<JoinRelation> joinRelations, String next,
@@ -381,12 +368,17 @@ public class DataModelNode extends SemanticNode {
} }
} }
private static List<DataModel> getLinkDataSources(Set<String> baseIdentifiers, private static List<DataModel> findRelatedModelsByIdentifier(Ontology ontology,
Set<String> queryDimension, Set<String> measures, DataModel baseDataModel, DataModel baseDataModel, Set<String> queryDimension, Set<String> measures) {
S2CalciteSchema schema) { Set<String> baseIdentifiers = baseDataModel.getIdentifiers().stream().map(Identify::getName)
.collect(Collectors.toSet());
if (baseIdentifiers.isEmpty()) {
return Collections.EMPTY_LIST;
}
Set<String> linkDataSourceName = new HashSet<>(); Set<String> linkDataSourceName = new HashSet<>();
List<DataModel> linkDataModels = new ArrayList<>(); List<DataModel> linkDataModels = new ArrayList<>();
for (Map.Entry<String, DataModel> entry : schema.getDataModels().entrySet()) { for (Map.Entry<String, DataModel> entry : ontology.getDataModelMap().entrySet()) {
if (entry.getKey().equalsIgnoreCase(baseDataModel.getName())) { if (entry.getKey().equalsIgnoreCase(baseDataModel.getName())) {
continue; continue;
} }
@@ -417,9 +409,9 @@ public class DataModelNode extends SemanticNode {
} }
} }
} }
for (Map.Entry<String, List<Dimension>> entry : schema.getDimensions().entrySet()) { for (Map.Entry<String, List<Dimension>> entry : ontology.getDimensionMap().entrySet()) {
if (!queryDimension.isEmpty()) { if (!queryDimension.isEmpty()) {
Set<String> linkDimension = entry.getValue().stream().map(dd -> dd.getName()) Set<String> linkDimension = entry.getValue().stream().map(Dimension::getName)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
linkDimension.retainAll(queryDimension); linkDimension.retainAll(queryDimension);
if (!linkDimension.isEmpty()) { if (!linkDimension.isEmpty()) {
@@ -428,7 +420,7 @@ public class DataModelNode extends SemanticNode {
} }
} }
for (String linkName : linkDataSourceName) { for (String linkName : linkDataSourceName) {
linkDataModels.add(schema.getDataModels().get(linkName)); linkDataModels.add(ontology.getDataModelMap().get(linkName));
} }
if (!CollectionUtils.isEmpty(linkDataModels)) { if (!CollectionUtils.isEmpty(linkDataModels)) {
List<DataModel> all = new ArrayList<>(); List<DataModel> all = new ArrayList<>();
@@ -438,4 +430,5 @@ public class DataModelNode extends SemanticNode {
} }
return Lists.newArrayList(); return Lists.newArrayList();
} }
} }

View File

@@ -60,7 +60,8 @@ public class JoinRender extends Renderer {
} }
Set<String> queryAllDimension = new HashSet<>(); Set<String> queryAllDimension = new HashSet<>();
Set<String> measures = new HashSet<>(); Set<String> measures = new HashSet<>();
DataModelNode.getQueryDimensionMeasure(schema, metricCommand, queryAllDimension, measures); DataModelNode.getQueryDimensionMeasure(schema.getOntology(), metricCommand,
queryAllDimension, measures);
SqlNode left = null; SqlNode left = null;
TableView leftTable = null; TableView leftTable = null;
TableView innerView = new TableView(); TableView innerView = new TableView();

View File

@@ -33,7 +33,6 @@ public class ModelYamlManager {
ModelDetail modelDetail = modelResp.getModelDetail(); ModelDetail modelDetail = modelResp.getModelDetail();
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType()); DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
SysTimeDimensionBuilder.addSysTimeDimension(modelDetail.getDimensions(), engineAdaptor); SysTimeDimensionBuilder.addSysTimeDimension(modelDetail.getDimensions(), engineAdaptor);
addInterCntMetric(modelResp.getBizName(), modelDetail);
DataModelYamlTpl dataModelYamlTpl = new DataModelYamlTpl(); DataModelYamlTpl dataModelYamlTpl = new DataModelYamlTpl();
dataModelYamlTpl.setType(databaseResp.getType()); dataModelYamlTpl.setType(databaseResp.getType());
BeanUtils.copyProperties(modelDetail, dataModelYamlTpl); BeanUtils.copyProperties(modelDetail, dataModelYamlTpl);

View File

@@ -2,10 +2,8 @@ package com.tencent.supersonic.headless.server.manager;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.TagResp;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.*; import com.tencent.supersonic.headless.core.translator.calcite.s2sql.*;
import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Materialization.TimePartType; import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Materialization.TimePartType;
import com.tencent.supersonic.headless.core.translator.calcite.sql.S2CalciteSchema; import com.tencent.supersonic.headless.core.translator.calcite.sql.S2CalciteSchema;
@@ -64,82 +62,6 @@ public class SemanticSchemaManager {
return ontology; return ontology;
} }
public Ontology getTagSemanticModel(SemanticSchemaResp semanticSchemaResp) throws Exception {
if (CollectionUtils.isEmpty(semanticSchemaResp.getTags())) {
throw new Exception("semanticSchemaResp tag is empty");
}
Ontology ontology = buildOntology(semanticSchemaResp);
// Map<String, List<Dimension>> dimensions = new HashMap<>();
Map<Long, List<TagResp>> tagMap = new HashMap<>();
for (TagResp tagResp : semanticSchemaResp.getTags()) {
if (!tagMap.containsKey(tagResp.getModelId())) {
tagMap.put(tagResp.getModelId(), new ArrayList<>());
}
tagMap.get(tagResp.getModelId()).add(tagResp);
}
if (Objects.nonNull(ontology.getDataModelMap()) && !ontology.getDataModelMap().isEmpty()) {
for (Map.Entry<String, DataModel> entry : ontology.getDataModelMap().entrySet()) {
List<Dimension> modelDimensions = new ArrayList<>();
if (!ontology.getDimensionMap().containsKey(entry.getKey())) {
ontology.getDimensionMap().put(entry.getKey(), modelDimensions);
} else {
modelDimensions = ontology.getDimensionMap().get(entry.getKey());
}
if (tagMap.containsKey(entry.getValue().getId())) {
for (TagResp tagResp : tagMap.get(entry.getValue().getId())) {
addTagModel(tagResp, modelDimensions, ontology.getMetrics());
}
}
}
}
return ontology;
}
private void addTagModel(TagResp tagResp, List<Dimension> modelDimensions,
List<Metric> modelMetrics) throws Exception {
TagDefineType tagDefineType = TagDefineType.valueOf(tagResp.getTagDefineType());
switch (tagDefineType) {
case FIELD:
case DIMENSION:
if (TagDefineType.DIMENSION.equals(tagResp.getTagDefineType())) {
Optional<Dimension> modelDimension = modelDimensions.stream()
// .filter(d -> d.getBizName().equals(tagResp.getExpr()))
.findFirst();
if (modelDimension.isPresent()) {
modelDimension.get().setName(tagResp.getBizName());
return;
}
}
Dimension dimension = Dimension.builder().build();
dimension.setType("");
// dimension.setExpr(tagResp.getExpr());
dimension.setName(tagResp.getBizName());
dimension.setOwners("");
dimension.setBizName(tagResp.getBizName());
if (Objects.isNull(dimension.getDataType())) {
dimension.setDataType(DataType.UNKNOWN);
}
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams();
dimension.setDimensionTimeTypeParams(dimensionTimeTypeParams);
modelDimensions.add(dimension);
return;
case METRIC:
Optional<Metric> modelMetric = modelMetrics.stream()
// .filter(m -> m.getName().equalsIgnoreCase(tagResp.getExpr()))
.findFirst();
if (modelMetric.isPresent()) {
modelMetric.get().setName(tagResp.getBizName());
} else {
throw new Exception(
String.format("tag [{}] cant find the metric", tagResp.getBizName()));
}
return;
default:
}
}
public static List<Metric> getMetrics(final List<MetricYamlTpl> t) { public static List<Metric> getMetrics(final List<MetricYamlTpl> t) {
return getMetricsByMetricYamlTpl(t); return getMetricsByMetricYamlTpl(t);
} }

View File

@@ -41,8 +41,8 @@ public class S2CompanyDemo extends S2BaseDemo {
ModelResp model_brand = addModel_2(domain, demoDatabase); ModelResp model_brand = addModel_2(domain, demoDatabase);
ModelResp model_brand_revenue = addModel_3(domain, demoDatabase); ModelResp model_brand_revenue = addModel_3(domain, demoDatabase);
addModelRela(domain, model_company, model_brand, "company_id"); addModelRela(domain, model_brand, model_company, "company_id");
addModelRela(domain, model_brand, model_brand_revenue, "brand_id"); addModelRela(domain, model_brand_revenue, model_brand, "brand_id");
DataSetResp dataset = addDataSet(domain); DataSetResp dataset = addDataSet(domain);
addAgent(dataset.getId()); addAgent(dataset.getId());
@@ -106,8 +106,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures); modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query"); modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT company_id,company_name,headquarter_address," modelDetail.setSqlQuery("SELECT * FROM company");
+ "company_established_time,founder,ceo,annual_turnover,employee_count FROM company");
modelReq.setModelDetail(modelDetail); modelReq.setModelDetail(modelDetail);
ModelResp companyModel = modelService.createModel(modelReq, defaultUser); ModelResp companyModel = modelService.createModel(modelReq, defaultUser);
@@ -146,8 +145,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures); modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query"); modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT brand_id,brand_name,brand_established_time," modelDetail.setSqlQuery("SELECT * FROM brand");
+ "company_id,legal_representative,registered_capital FROM brand");
modelReq.setModelDetail(modelDetail); modelReq.setModelDetail(modelDetail);
ModelResp brandModel = modelService.createModel(modelReq, defaultUser); ModelResp brandModel = modelService.createModel(modelReq, defaultUser);
@@ -187,8 +185,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures); modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query"); modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT year_time,brand_id,revenue,profit," modelDetail.setSqlQuery("SELECT * FROM brand_revenue");
+ "revenue_growth_year_on_year,profit_growth_year_on_year FROM brand_revenue");
modelReq.setModelDetail(modelDetail); modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, defaultUser); return modelService.createModel(modelReq, defaultUser);
} }
@@ -227,7 +224,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelRelaReq.setDomainId(domain.getId()); modelRelaReq.setDomainId(domain.getId());
modelRelaReq.setFromModelId(fromModel.getId()); modelRelaReq.setFromModelId(fromModel.getId());
modelRelaReq.setToModelId(toModel.getId()); modelRelaReq.setToModelId(toModel.getId());
modelRelaReq.setJoinType("left join"); modelRelaReq.setJoinType("inner join");
modelRelaReq.setJoinConditions(joinConditions); modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, defaultUser); modelRelaService.save(modelRelaReq, defaultUser);
} }

View File

@@ -199,6 +199,7 @@ public class S2VisitsDemo extends S2BaseDemo {
List<Dim> dimensions = new ArrayList<>(); List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("部门", "department", DimensionType.categorical, 1)); dimensions.add(new Dim("部门", "department", DimensionType.categorical, 1));
// dimensions.add(new Dim("用户", "user_name", DimensionType.categorical, 1));
modelDetail.setDimensions(dimensions); modelDetail.setDimensions(dimensions);
List<Field> fields = Lists.newArrayList(); List<Field> fields = Lists.newArrayList();
fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build()); fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build());

View File

@@ -5,7 +5,10 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import com.tencent.supersonic.chat.BaseTest; import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.*; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.config.ChatModel; import com.tencent.supersonic.common.config.ChatModel;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
@@ -133,11 +136,28 @@ public class Text2SQLEval extends BaseTest {
assert result.getTextResult().contains("3"); assert result.getTextResult().contains("3");
} }
@Test
public void test_detail_query() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getTextResult().contains("Model Y");
assert result.getTextResult().contains("Model 3");
}
public Agent getLLMAgent() { public Agent getLLMAgent() {
Agent agent = new Agent(); Agent agent = new Agent();
agent.setName("Agent for Test"); agent.setName("Agent for Test");
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getDatasetTool()); DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(DataUtils.productDatasetId));
toolConfig.getTools().add(datasetTool);
DatasetTool datasetTool2 = new DatasetTool();
datasetTool2.setType(AgentToolType.DATASET);
datasetTool2.setDataSetIds(Lists.newArrayList(DataUtils.companyDatasetId));
toolConfig.getTools().add(datasetTool2);
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// create chat model for this evaluation // create chat model for this evaluation
ChatModel chatModel = new ChatModel(); ChatModel chatModel = new ChatModel();
@@ -154,11 +174,4 @@ public class Text2SQLEval extends BaseTest {
return agent; return agent;
} }
private static DatasetTool getDatasetTool() {
DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(1L));
return datasetTool;
}
} }