From dad065d0bae020246560401c4f590a812c145491 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Wed, 27 Nov 2024 11:29:29 +0800 Subject: [PATCH] [improvement][headless]Clean code logic of headless translator. --- .../translator/DefaultSemanticTranslator.java | 2 +- .../translator/calcite/s2sql/Ontology.java | 4 - .../translator/calcite/sql/SqlBuilder.java | 16 +- .../calcite/sql/node/DataModelNode.java | 309 +++++++++--------- .../calcite/sql/render/JoinRender.java | 3 +- .../server/manager/ModelYamlManager.java | 1 - .../server/manager/SemanticSchemaManager.java | 78 ----- .../supersonic/demo/S2CompanyDemo.java | 15 +- .../tencent/supersonic/demo/S2VisitsDemo.java | 1 + .../supersonic/evaluation/Text2SQLEval.java | 31 +- 10 files changed, 184 insertions(+), 276 deletions(-) diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java index 662819598..ef3b73e97 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java @@ -50,7 +50,7 @@ public class DefaultSemanticTranslator implements SemanticTranslator { } } catch (Exception e) { queryStatement.setErrMsg(e.getMessage()); - log.error("Failed to translate semantic query [{}]", e.getMessage(), e); + log.error("Failed to translate query [{}]", e.getMessage(), e); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Ontology.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Ontology.java index da2e21698..c44e7225c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Ontology.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Ontology.java @@ -25,8 +25,4 @@ public class Ontology { .collect(Collectors.toList()); } - public Map getModelMap() { - return dataModelMap.values().stream() - .collect(Collectors.toMap(DataModel::getId, dataSource -> dataSource)); - } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/SqlBuilder.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/SqlBuilder.java index 570854a07..54c79e468 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/SqlBuilder.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/SqlBuilder.java @@ -52,7 +52,7 @@ public class SqlBuilder { // find relevant data models scope = SchemaBuilder.getScope(schema); List dataModels = - DataModelNode.getRelatedDataModels(scope, schema, ontologyQueryParam); + DataModelNode.getQueryDataModels(scope, schema, ontologyQueryParam); if (dataModels == null || dataModels.isEmpty()) { throw new Exception("data model not found"); } @@ -98,20 +98,6 @@ public class SqlBuilder { return SemanticNode.getSql(parserNode, engineType); } - private String rewrite(String sql, EngineType engineType) { - try { - SqlNode sqlNode = - SqlParser.create(sql, Configuration.getParserConfig(engineType)).parseStmt(); - if (Objects.nonNull(sqlNode)) { - return SemanticNode.getSql( - SemanticNode.optimize(scope, schema, sqlNode, engineType), engineType); - } - } catch (Exception e) { - log.error("optimize error {}", e.toString()); - } - return ""; - } - private void optimizeParseNode(EngineType engineType) { if (Objects.isNull(schema.getRuntimeOptions()) || Objects.isNull(schema.getRuntimeOptions().getEnableOptimize()) diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataModelNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataModelNode.java index 5d4b38b49..b80458045 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataModelNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataModelNode.java @@ -4,36 +4,17 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.common.calcite.Configuration; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.pojo.enums.EngineType; -import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Constants; -import com.tencent.supersonic.headless.core.translator.calcite.s2sql.DataModel; -import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Dimension; -import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Identify; -import com.tencent.supersonic.headless.core.translator.calcite.s2sql.JoinRelation; -import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Measure; -import com.tencent.supersonic.headless.core.translator.calcite.s2sql.OntologyQueryParam; +import com.tencent.supersonic.headless.core.translator.calcite.s2sql.*; import com.tencent.supersonic.headless.core.translator.calcite.sql.S2CalciteSchema; import com.tencent.supersonic.headless.core.translator.calcite.sql.SchemaBuilder; import lombok.extern.slf4j.Slf4j; -import org.apache.calcite.sql.SqlBasicCall; -import org.apache.calcite.sql.SqlDataTypeSpec; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.SqlUserDefinedTypeNameSpec; +import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidatorScope; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; @Slf4j @@ -53,7 +34,7 @@ public class DataModelNode extends SemanticNode { } } if (sqlTable.isEmpty()) { - throw new Exception("DatasourceNode build error [tableSqlNode not found]"); + throw new Exception("DataModelNode build error [tableSqlNode not found]"); } SqlNode source = getTable(sqlTable, scope, EngineType.fromString(dataModel.getType())); addSchema(scope, dataModel, sqlTable); @@ -149,166 +130,171 @@ public class DataModelNode extends SemanticNode { return dataModelList.stream().map(d -> d.getName()).collect(Collectors.joining("_")); } - public static void getQueryDimensionMeasure(S2CalciteSchema schema, - OntologyQueryParam queryParam, Set queryDimensions, Set queryMeasures) { + public static void getQueryDimensionMeasure(Ontology ontology, OntologyQueryParam queryParam, + Set queryDimensions, Set queryMeasures) { queryDimensions.addAll(queryParam.getDimensions().stream() .map(d -> d.contains(Constants.DIMENSION_IDENTIFY) ? d.split(Constants.DIMENSION_IDENTIFY)[1] : d) .collect(Collectors.toSet())); Set schemaMetricName = - schema.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet()); - schema.getMetrics().stream().filter(m -> queryParam.getMetrics().contains(m.getName())) + ontology.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet()); + ontology.getMetrics().stream().filter(m -> queryParam.getMetrics().contains(m.getName())) .forEach(m -> m.getMetricTypeParams().getMeasures().stream() .forEach(mm -> queryMeasures.add(mm.getName()))); queryParam.getMetrics().stream().filter(m -> !schemaMetricName.contains(m)) .forEach(m -> queryMeasures.add(m)); } - public static void mergeQueryFilterDimensionMeasure(S2CalciteSchema schema, - OntologyQueryParam metricCommand, Set queryDimension, Set measures, + public static void mergeQueryFilterDimensionMeasure(Ontology ontology, + OntologyQueryParam queryParam, Set dimensions, Set measures, SqlValidatorScope scope) throws Exception { - EngineType engineType = schema.getOntology().getDatabase().getType(); - if (Objects.nonNull(metricCommand.getWhere()) && !metricCommand.getWhere().isEmpty()) { + EngineType engineType = ontology.getDatabase().getType(); + if (Objects.nonNull(queryParam.getWhere()) && !queryParam.getWhere().isEmpty()) { Set filterConditions = new HashSet<>(); - FilterNode.getFilterField(parse(metricCommand.getWhere(), scope, engineType), + FilterNode.getFilterField(parse(queryParam.getWhere(), scope, engineType), filterConditions); Set queryMeasures = new HashSet<>(measures); - Set schemaMetricName = - schema.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet()); + Set schemaMetricName = ontology.getMetrics().stream().map(m -> m.getName()) + .collect(Collectors.toSet()); for (String filterCondition : filterConditions) { if (schemaMetricName.contains(filterCondition)) { - schema.getMetrics().stream() + ontology.getMetrics().stream() .filter(m -> m.getName().equalsIgnoreCase(filterCondition)) .forEach(m -> m.getMetricTypeParams().getMeasures().stream() .forEach(mm -> queryMeasures.add(mm.getName()))); continue; } - queryDimension.add(filterCondition); + dimensions.add(filterCondition); } measures.clear(); measures.addAll(queryMeasures); } } - public static List getRelatedDataModels(SqlValidatorScope scope, + public static List getQueryDataModels(SqlValidatorScope scope, S2CalciteSchema schema, OntologyQueryParam queryParam) throws Exception { - List dataModels = new ArrayList<>(); - - // check by metric + Ontology ontology = schema.getOntology(); + // get query measures and dimensions Set queryMeasures = new HashSet<>(); Set queryDimensions = new HashSet<>(); - getQueryDimensionMeasure(schema, queryParam, queryDimensions, queryMeasures); - DataModel baseDataModel = null; - // one , match measure count - Map dataSourceMeasures = new HashMap<>(); - for (Map.Entry entry : schema.getDataModels().entrySet()) { - Set sourceMeasure = entry.getValue().getMeasures().stream() - .map(mm -> mm.getName()).collect(Collectors.toSet()); - sourceMeasure.retainAll(queryMeasures); - dataSourceMeasures.put(entry.getKey(), sourceMeasure.size()); - } - log.info("metrics: [{}]", dataSourceMeasures); - Optional> base = dataSourceMeasures.entrySet().stream() - .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst(); - if (base.isPresent()) { - baseDataModel = schema.getDataModels().get(base.get().getKey()); - dataModels.add(baseDataModel); - } - // second , check match all dimension and metric - if (baseDataModel != null) { - Set filterMeasure = new HashSet<>(); - Set sourceMeasure = baseDataModel.getMeasures().stream().map(mm -> mm.getName()) - .collect(Collectors.toSet()); - Set dimension = baseDataModel.getDimensions().stream().map(dd -> dd.getName()) - .collect(Collectors.toSet()); - baseDataModel.getIdentifiers().stream().forEach(i -> dimension.add(i.getName())); - if (schema.getDimensions().containsKey(baseDataModel.getName())) { - schema.getDimensions().get(baseDataModel.getName()).stream() - .forEach(d -> dimension.add(d.getName())); - } - filterMeasure.addAll(sourceMeasure); - filterMeasure.addAll(dimension); - EngineType engineType = schema.getOntology().getDatabase().getType(); - mergeQueryFilterDimensionMeasure(schema, queryParam, queryDimensions, queryMeasures, - scope); - boolean isAllMatch = checkMatch(sourceMeasure, queryDimensions, queryMeasures, - dimension, queryParam, scope, engineType); - if (isAllMatch) { - log.debug("baseDataModel match all "); - return dataModels; - } - // find all dataSource has the same identifiers - List linkDataModels = getLinkDataSourcesByJoinRelation(queryDimensions, - queryMeasures, baseDataModel, schema); - if (CollectionUtils.isEmpty(linkDataModels)) { - log.debug("baseDataModel get by identifiers "); - Set baseIdentifiers = baseDataModel.getIdentifiers().stream() - .map(i -> i.getName()).collect(Collectors.toSet()); - if (baseIdentifiers.isEmpty()) { - throw new Exception( - "datasource error : " + baseDataModel.getName() + " miss identifier"); - } - linkDataModels = getLinkDataSources(baseIdentifiers, queryDimensions, queryMeasures, - baseDataModel, schema); - if (linkDataModels.isEmpty()) { - throw new Exception(String.format( - "not find the match datasource : dimension[%s],measure[%s]", + getQueryDimensionMeasure(ontology, queryParam, queryDimensions, queryMeasures); + mergeQueryFilterDimensionMeasure(ontology, queryParam, queryDimensions, queryMeasures, + scope); + + // first, find the base model + DataModel baseDataModel = findBaseModel(ontology, queryMeasures, queryDimensions); + if (Objects.isNull(baseDataModel)) { + throw new RuntimeException( + String.format("could not find matching dataModel, dimensions:%s, measures:%s", queryDimensions, queryMeasures)); - } - } - log.debug("linkDataModels {}", linkDataModels); - return linkDataModels; + } + // if the base model matches all queried measures and dimensions, just return + if (checkMatch(baseDataModel, queryMeasures, queryDimensions)) { + log.debug("baseDataModel match all measures and dimensions"); + return Collections.singletonList(baseDataModel); } - return dataModels; + // second, traverse the ontology to find other related dataModels + List relatedDataModels = findRelatedModelsByRelation(ontology, baseDataModel, + queryDimensions, queryMeasures); + if (CollectionUtils.isEmpty(relatedDataModels)) { + relatedDataModels = findRelatedModelsByIdentifier(ontology, baseDataModel, + queryDimensions, queryMeasures); + } + if (CollectionUtils.isEmpty(relatedDataModels)) { + relatedDataModels = Collections.singletonList(baseDataModel); + } + + log.debug("relatedDataModels {}", relatedDataModels); + return relatedDataModels; } - private static boolean checkMatch(Set sourceMeasure, Set queryDimension, - Set measures, Set dimension, OntologyQueryParam metricCommand, - SqlValidatorScope scope, EngineType engineType) throws Exception { - boolean isAllMatch = true; - sourceMeasure.retainAll(measures); - if (sourceMeasure.size() < measures.size()) { - log.info("baseDataSource measures not match all measure"); - // check dimension again - Set dimensionMeasures = new HashSet<>(); - dimensionMeasures.addAll(dimension); - dimensionMeasures.retainAll(measures); - if (sourceMeasure.size() + dimensionMeasures.size() < measures.size()) { - log.info("baseDataSource not match all measure"); - isAllMatch = false; + private static DataModel findBaseModel(Ontology ontology, Set queryMeasures, + Set queryDimensions) { + DataModel dataModel = null; + // first, try to find the model with the most matching measures + Map dataModelMeasuresCount = new HashMap<>(); + for (Map.Entry entry : ontology.getDataModelMap().entrySet()) { + Set sourceMeasure = entry.getValue().getMeasures().stream() + .map(Measure::getName).collect(Collectors.toSet()); + sourceMeasure.retainAll(queryMeasures); + dataModelMeasuresCount.put(entry.getKey(), sourceMeasure.size()); + } + log.info("dataModelMeasureCount: [{}]", dataModelMeasuresCount); + Optional> base = + dataModelMeasuresCount.entrySet().stream().filter(e -> e.getValue() > 0) + .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst(); + + if (base.isPresent()) { + dataModel = ontology.getDataModelMap().get(base.get().getKey()); + } else { + // second, try to find the model with the most matching dimensions + Map dataModelDimCount = new HashMap<>(); + for (Map.Entry> entry : ontology.getDimensionMap().entrySet()) { + Set modelDimensions = entry.getValue().stream().map(Dimension::getName) + .collect(Collectors.toSet()); + modelDimensions.retainAll(queryDimensions); + dataModelDimCount.put(entry.getKey(), modelDimensions.size()); + } + log.info("dataModelDimCount: [{}]", dataModelDimCount); + base = dataModelDimCount.entrySet().stream().filter(e -> e.getValue() > 0) + .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst(); + if (base.isPresent()) { + dataModel = ontology.getDataModelMap().get(base.get().getKey()); } } - measures.removeAll(sourceMeasure); - dimension.retainAll(queryDimension); - if (dimension.size() < queryDimension.size()) { - log.debug("baseDataSource not match all dimension"); + return dataModel; + } + + private static boolean checkMatch(DataModel baseDataModel, Set queryMeasures, + Set queryDimension) { + boolean isAllMatch = true; + Set baseMeasures = baseDataModel.getMeasures().stream().map(Measure::getName) + .collect(Collectors.toSet()); + Set baseDimensions = baseDataModel.getDimensions().stream().map(Dimension::getName) + .collect(Collectors.toSet()); + baseDataModel.getIdentifiers().stream().forEach(i -> baseDimensions.add(i.getName())); + + baseMeasures.retainAll(queryMeasures); + if (baseMeasures.size() < queryMeasures.size()) { + // check dimension again + Set dimensionMeasures = new HashSet<>(); + dimensionMeasures.addAll(baseDimensions); + dimensionMeasures.retainAll(queryMeasures); + if (baseMeasures.size() + dimensionMeasures.size() < queryMeasures.size()) { + log.info("baseDataModel not match all measures"); + isAllMatch = false; + } + queryMeasures.removeAll(dimensionMeasures); + } + queryMeasures.removeAll(baseMeasures); + + baseDimensions.retainAll(queryDimension); + if (baseDimensions.size() < queryDimension.size()) { + log.debug("baseDataModel not match all dimensions"); isAllMatch = false; } - queryDimension.removeAll(dimension); + queryDimension.removeAll(baseDimensions); - if (metricCommand.getWhere() != null && !metricCommand.getWhere().isEmpty()) { - Set whereFields = new HashSet<>(); - SqlNode sqlNode = parse(metricCommand.getWhere(), scope, engineType); - FilterNode.getFilterField(sqlNode, whereFields); - } return isAllMatch; } - private static List getLinkDataSourcesByJoinRelation(Set queryDimension, - Set measures, DataModel baseDataModel, S2CalciteSchema schema) { - Set linkDataSourceName = new HashSet<>(); - List linkDataModels = new ArrayList<>(); + private static List findRelatedModelsByRelation(Ontology ontology, + DataModel baseDataModel, Set queryDimensions, Set queryMeasures) { + Set joinDataModelNames = new HashSet<>(); + List joinDataModels = new ArrayList<>(); Set before = new HashSet<>(); before.add(baseDataModel.getName()); - if (!CollectionUtils.isEmpty(schema.getJoinRelations())) { + + if (!CollectionUtils.isEmpty(ontology.getJoinRelations())) { Set visitJoinRelations = new HashSet<>(); List sortedJoinRelation = new ArrayList<>(); - sortJoinRelation(schema.getJoinRelations(), baseDataModel.getName(), visitJoinRelations, - sortedJoinRelation); - schema.getJoinRelations().stream().filter(j -> !visitJoinRelations.contains(j.getId())) + sortJoinRelation(ontology.getJoinRelations(), baseDataModel.getName(), + visitJoinRelations, sortedJoinRelation); + ontology.getJoinRelations().stream() + .filter(j -> !visitJoinRelations.contains(j.getId())) .forEach(j -> sortedJoinRelation.add(j)); for (JoinRelation joinRelation : sortedJoinRelation) { if (!before.contains(joinRelation.getLeft()) @@ -317,53 +303,54 @@ public class DataModelNode extends SemanticNode { } boolean isMatch = false; boolean isRight = before.contains(joinRelation.getLeft()); - DataModel other = isRight ? schema.getDataModels().get(joinRelation.getRight()) - : schema.getDataModels().get(joinRelation.getLeft()); - if (!queryDimension.isEmpty()) { + DataModel other = isRight ? ontology.getDataModelMap().get(joinRelation.getRight()) + : ontology.getDataModelMap().get(joinRelation.getLeft()); + if (!queryDimensions.isEmpty()) { Set linkDimension = other.getDimensions().stream() .map(dd -> dd.getName()).collect(Collectors.toSet()); other.getIdentifiers().stream().forEach(i -> linkDimension.add(i.getName())); - linkDimension.retainAll(queryDimension); + linkDimension.retainAll(queryDimensions); if (!linkDimension.isEmpty()) { isMatch = true; } } - Set linkMeasure = other.getMeasures().stream().map(mm -> mm.getName()) + Set linkMeasure = other.getMeasures().stream().map(Measure::getName) .collect(Collectors.toSet()); - linkMeasure.retainAll(measures); + linkMeasure.retainAll(queryMeasures); if (!linkMeasure.isEmpty()) { isMatch = true; } - if (!isMatch && schema.getDimensions().containsKey(other.getName())) { - Set linkDimension = schema.getDimensions().get(other.getName()).stream() - .map(dd -> dd.getName()).collect(Collectors.toSet()); - linkDimension.retainAll(queryDimension); + if (!isMatch && ontology.getDimensionMap().containsKey(other.getName())) { + Set linkDimension = ontology.getDimensionMap().get(other.getName()) + .stream().map(dd -> dd.getName()).collect(Collectors.toSet()); + linkDimension.retainAll(queryDimensions); if (!linkDimension.isEmpty()) { isMatch = true; } } if (isMatch) { - linkDataSourceName.add(other.getName()); + joinDataModelNames.add(other.getName()); before.add(other.getName()); } } } - if (!CollectionUtils.isEmpty(linkDataSourceName)) { + if (!CollectionUtils.isEmpty(joinDataModelNames)) { Map orders = new HashMap<>(); - linkDataSourceName.add(baseDataModel.getName()); + joinDataModelNames.add(baseDataModel.getName()); orders.put(baseDataModel.getName(), 0L); - for (JoinRelation joinRelation : schema.getJoinRelations()) { - if (linkDataSourceName.contains(joinRelation.getLeft()) - && linkDataSourceName.contains(joinRelation.getRight())) { + for (JoinRelation joinRelation : ontology.getJoinRelations()) { + if (joinDataModelNames.contains(joinRelation.getLeft()) + && joinDataModelNames.contains(joinRelation.getRight())) { orders.put(joinRelation.getLeft(), 0L); orders.put(joinRelation.getRight(), 1L); } } orders.entrySet().stream().sorted(Map.Entry.comparingByValue()).forEach(d -> { - linkDataModels.add(schema.getDataModels().get(d.getKey())); + joinDataModels.add(ontology.getDataModelMap().get(d.getKey())); }); } - return linkDataModels; + + return joinDataModels; } private static void sortJoinRelation(List joinRelations, String next, @@ -381,12 +368,17 @@ public class DataModelNode extends SemanticNode { } } - private static List getLinkDataSources(Set baseIdentifiers, - Set queryDimension, Set measures, DataModel baseDataModel, - S2CalciteSchema schema) { + private static List findRelatedModelsByIdentifier(Ontology ontology, + DataModel baseDataModel, Set queryDimension, Set measures) { + Set baseIdentifiers = baseDataModel.getIdentifiers().stream().map(Identify::getName) + .collect(Collectors.toSet()); + if (baseIdentifiers.isEmpty()) { + return Collections.EMPTY_LIST; + } + Set linkDataSourceName = new HashSet<>(); List linkDataModels = new ArrayList<>(); - for (Map.Entry entry : schema.getDataModels().entrySet()) { + for (Map.Entry entry : ontology.getDataModelMap().entrySet()) { if (entry.getKey().equalsIgnoreCase(baseDataModel.getName())) { continue; } @@ -417,9 +409,9 @@ public class DataModelNode extends SemanticNode { } } } - for (Map.Entry> entry : schema.getDimensions().entrySet()) { + for (Map.Entry> entry : ontology.getDimensionMap().entrySet()) { if (!queryDimension.isEmpty()) { - Set linkDimension = entry.getValue().stream().map(dd -> dd.getName()) + Set linkDimension = entry.getValue().stream().map(Dimension::getName) .collect(Collectors.toSet()); linkDimension.retainAll(queryDimension); if (!linkDimension.isEmpty()) { @@ -428,7 +420,7 @@ public class DataModelNode extends SemanticNode { } } for (String linkName : linkDataSourceName) { - linkDataModels.add(schema.getDataModels().get(linkName)); + linkDataModels.add(ontology.getDataModelMap().get(linkName)); } if (!CollectionUtils.isEmpty(linkDataModels)) { List all = new ArrayList<>(); @@ -438,4 +430,5 @@ public class DataModelNode extends SemanticNode { } return Lists.newArrayList(); } + } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java index 531a773b4..50fad152a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java @@ -60,7 +60,8 @@ public class JoinRender extends Renderer { } Set queryAllDimension = new HashSet<>(); Set measures = new HashSet<>(); - DataModelNode.getQueryDimensionMeasure(schema, metricCommand, queryAllDimension, measures); + DataModelNode.getQueryDimensionMeasure(schema.getOntology(), metricCommand, + queryAllDimension, measures); SqlNode left = null; TableView leftTable = null; TableView innerView = new TableView(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java index 3a96b718f..44b2fb6bd 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java @@ -33,7 +33,6 @@ public class ModelYamlManager { ModelDetail modelDetail = modelResp.getModelDetail(); DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType()); SysTimeDimensionBuilder.addSysTimeDimension(modelDetail.getDimensions(), engineAdaptor); - addInterCntMetric(modelResp.getBizName(), modelDetail); DataModelYamlTpl dataModelYamlTpl = new DataModelYamlTpl(); dataModelYamlTpl.setType(databaseResp.getType()); BeanUtils.copyProperties(modelDetail, dataModelYamlTpl); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java index 5a6228b9e..cb4d828af 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java @@ -2,10 +2,8 @@ package com.tencent.supersonic.headless.server.manager; import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; -import com.tencent.supersonic.headless.api.pojo.response.TagResp; import com.tencent.supersonic.headless.core.translator.calcite.s2sql.*; import com.tencent.supersonic.headless.core.translator.calcite.s2sql.Materialization.TimePartType; import com.tencent.supersonic.headless.core.translator.calcite.sql.S2CalciteSchema; @@ -64,82 +62,6 @@ public class SemanticSchemaManager { return ontology; } - public Ontology getTagSemanticModel(SemanticSchemaResp semanticSchemaResp) throws Exception { - if (CollectionUtils.isEmpty(semanticSchemaResp.getTags())) { - throw new Exception("semanticSchemaResp tag is empty"); - } - Ontology ontology = buildOntology(semanticSchemaResp); - // Map> dimensions = new HashMap<>(); - Map> tagMap = new HashMap<>(); - for (TagResp tagResp : semanticSchemaResp.getTags()) { - if (!tagMap.containsKey(tagResp.getModelId())) { - tagMap.put(tagResp.getModelId(), new ArrayList<>()); - } - tagMap.get(tagResp.getModelId()).add(tagResp); - } - if (Objects.nonNull(ontology.getDataModelMap()) && !ontology.getDataModelMap().isEmpty()) { - for (Map.Entry entry : ontology.getDataModelMap().entrySet()) { - List modelDimensions = new ArrayList<>(); - if (!ontology.getDimensionMap().containsKey(entry.getKey())) { - ontology.getDimensionMap().put(entry.getKey(), modelDimensions); - } else { - modelDimensions = ontology.getDimensionMap().get(entry.getKey()); - } - if (tagMap.containsKey(entry.getValue().getId())) { - for (TagResp tagResp : tagMap.get(entry.getValue().getId())) { - addTagModel(tagResp, modelDimensions, ontology.getMetrics()); - } - } - } - } - - return ontology; - } - - private void addTagModel(TagResp tagResp, List modelDimensions, - List modelMetrics) throws Exception { - TagDefineType tagDefineType = TagDefineType.valueOf(tagResp.getTagDefineType()); - switch (tagDefineType) { - case FIELD: - case DIMENSION: - if (TagDefineType.DIMENSION.equals(tagResp.getTagDefineType())) { - Optional modelDimension = modelDimensions.stream() - // .filter(d -> d.getBizName().equals(tagResp.getExpr())) - .findFirst(); - if (modelDimension.isPresent()) { - modelDimension.get().setName(tagResp.getBizName()); - return; - } - } - Dimension dimension = Dimension.builder().build(); - dimension.setType(""); - // dimension.setExpr(tagResp.getExpr()); - dimension.setName(tagResp.getBizName()); - dimension.setOwners(""); - dimension.setBizName(tagResp.getBizName()); - if (Objects.isNull(dimension.getDataType())) { - dimension.setDataType(DataType.UNKNOWN); - } - - DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams(); - dimension.setDimensionTimeTypeParams(dimensionTimeTypeParams); - modelDimensions.add(dimension); - return; - case METRIC: - Optional modelMetric = modelMetrics.stream() - // .filter(m -> m.getName().equalsIgnoreCase(tagResp.getExpr())) - .findFirst(); - if (modelMetric.isPresent()) { - modelMetric.get().setName(tagResp.getBizName()); - } else { - throw new Exception( - String.format("tag [{}] cant find the metric", tagResp.getBizName())); - } - return; - default: - } - } - public static List getMetrics(final List t) { return getMetricsByMetricYamlTpl(t); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2CompanyDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2CompanyDemo.java index df425a2c9..6c19cec37 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2CompanyDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2CompanyDemo.java @@ -41,8 +41,8 @@ public class S2CompanyDemo extends S2BaseDemo { ModelResp model_brand = addModel_2(domain, demoDatabase); ModelResp model_brand_revenue = addModel_3(domain, demoDatabase); - addModelRela(domain, model_company, model_brand, "company_id"); - addModelRela(domain, model_brand, model_brand_revenue, "brand_id"); + addModelRela(domain, model_brand, model_company, "company_id"); + addModelRela(domain, model_brand_revenue, model_brand, "brand_id"); DataSetResp dataset = addDataSet(domain); addAgent(dataset.getId()); @@ -106,8 +106,7 @@ public class S2CompanyDemo extends S2BaseDemo { modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery("SELECT company_id,company_name,headquarter_address," - + "company_established_time,founder,ceo,annual_turnover,employee_count FROM company"); + modelDetail.setSqlQuery("SELECT * FROM company"); modelReq.setModelDetail(modelDetail); ModelResp companyModel = modelService.createModel(modelReq, defaultUser); @@ -146,8 +145,7 @@ public class S2CompanyDemo extends S2BaseDemo { modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery("SELECT brand_id,brand_name,brand_established_time," - + "company_id,legal_representative,registered_capital FROM brand"); + modelDetail.setSqlQuery("SELECT * FROM brand"); modelReq.setModelDetail(modelDetail); ModelResp brandModel = modelService.createModel(modelReq, defaultUser); @@ -187,8 +185,7 @@ public class S2CompanyDemo extends S2BaseDemo { modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery("SELECT year_time,brand_id,revenue,profit," - + "revenue_growth_year_on_year,profit_growth_year_on_year FROM brand_revenue"); + modelDetail.setSqlQuery("SELECT * FROM brand_revenue"); modelReq.setModelDetail(modelDetail); return modelService.createModel(modelReq, defaultUser); } @@ -227,7 +224,7 @@ public class S2CompanyDemo extends S2BaseDemo { modelRelaReq.setDomainId(domain.getId()); modelRelaReq.setFromModelId(fromModel.getId()); modelRelaReq.setToModelId(toModel.getId()); - modelRelaReq.setJoinType("left join"); + modelRelaReq.setJoinType("inner join"); modelRelaReq.setJoinConditions(joinConditions); modelRelaService.save(modelRelaReq, defaultUser); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 35896939e..69e04f2ae 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -199,6 +199,7 @@ public class S2VisitsDemo extends S2BaseDemo { List dimensions = new ArrayList<>(); dimensions.add(new Dim("部门", "department", DimensionType.categorical, 1)); + // dimensions.add(new Dim("用户", "user_name", DimensionType.categorical, 1)); modelDetail.setDimensions(dimensions); List fields = Lists.newArrayList(); fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 21a6869c6..4dd9dfb7d 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -5,7 +5,10 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.tencent.supersonic.chat.BaseTest; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; -import com.tencent.supersonic.chat.server.agent.*; +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.agent.AgentToolType; +import com.tencent.supersonic.chat.server.agent.DatasetTool; +import com.tencent.supersonic.chat.server.agent.ToolConfig; import com.tencent.supersonic.common.config.ChatModel; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.User; @@ -133,11 +136,28 @@ public class Text2SQLEval extends BaseTest { assert result.getTextResult().contains("3"); } + @Test + public void test_detail_query() throws Exception { + long start = System.currentTimeMillis(); + QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agentId); + durations.add(System.currentTimeMillis() - start); + assert result.getQueryColumns().size() >= 1; + assert result.getTextResult().contains("Model Y"); + assert result.getTextResult().contains("Model 3"); + } + public Agent getLLMAgent() { Agent agent = new Agent(); agent.setName("Agent for Test"); ToolConfig toolConfig = new ToolConfig(); - toolConfig.getTools().add(getDatasetTool()); + DatasetTool datasetTool = new DatasetTool(); + datasetTool.setType(AgentToolType.DATASET); + datasetTool.setDataSetIds(Lists.newArrayList(DataUtils.productDatasetId)); + toolConfig.getTools().add(datasetTool); + DatasetTool datasetTool2 = new DatasetTool(); + datasetTool2.setType(AgentToolType.DATASET); + datasetTool2.setDataSetIds(Lists.newArrayList(DataUtils.companyDatasetId)); + toolConfig.getTools().add(datasetTool2); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); // create chat model for this evaluation ChatModel chatModel = new ChatModel(); @@ -154,11 +174,4 @@ public class Text2SQLEval extends BaseTest { return agent; } - private static DatasetTool getDatasetTool() { - DatasetTool datasetTool = new DatasetTool(); - datasetTool.setType(AgentToolType.DATASET); - datasetTool.setDataSetIds(Lists.newArrayList(1L)); - - return datasetTool; - } }