From 4e653c1fb1350a0d1a990837b48f02f98d048dd6 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sun, 5 Jan 2025 16:09:42 +0800 Subject: [PATCH] [improvement][headless]Expression replacement logic supports more complex sql. --- .../common/jsqlparser/SqlReplaceHelper.java | 15 ++++++----- .../parser/DimExpressionParser.java | 5 +++- .../parser/MetricExpressionParser.java | 5 +++- .../service/impl/SchemaServiceImpl.java | 1 + .../supersonic/headless/TranslatorTest.java | 14 ++++++++++ .../src/test/resources/sql/testSubquery.sql | 27 +++++++++++++++++++ 6 files changed, 59 insertions(+), 8 deletions(-) create mode 100644 launchers/standalone/src/test/resources/sql/testSubquery.sql diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 415b2a637..1bd114f9d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -627,7 +627,8 @@ public class SqlReplaceHelper { return expr; } - public static String replaceSqlByExpression(String sql, Map replace) { + public static String replaceSqlByExpression(String tableName, String sql, + Map replace) { Select selectStatement = SqlSelectHelper.getSelect(sql); List plainSelectList = new ArrayList<>(); if (selectStatement instanceof PlainSelect) { @@ -636,9 +637,8 @@ public class SqlReplaceHelper { selectStatement.getWithItemsList().forEach(withItem -> { plainSelectList.add(withItem.getSelect().getPlainSelect()); }); - } else { - plainSelectList.add((PlainSelect) selectStatement); } + plainSelectList.add((PlainSelect) selectStatement); } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { @@ -672,9 +672,12 @@ public class SqlReplaceHelper { List plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect plainSelect : plainSelects) { - replacePlainSelectByExpr(plainSelect, replace); - if (SqlSelectHelper.hasAggregateFunction(plainSelect)) { - SqlSelectHelper.addMissingGroupby(plainSelect); + Table table = (Table) plainSelect.getFromItem(); + if (table.getName().equals(tableName)) { + replacePlainSelectByExpr(plainSelect, replace); + if (SqlSelectHelper.hasAggregateFunction(plainSelect)) { + SqlSelectHelper.addMissingGroupby(plainSelect); + } } } return selectStatement.toString(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/DimExpressionParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/DimExpressionParser.java index aec1835ba..0b94d3f42 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/DimExpressionParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/DimExpressionParser.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.translator.parser; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; import com.tencent.supersonic.headless.core.pojo.OntologyQuery; @@ -40,7 +41,9 @@ public class DimExpressionParser implements QueryParser { Map bizName2Expr = getDimensionExpressions(semanticSchema, ontologyQuery); if (!CollectionUtils.isEmpty(bizName2Expr)) { - String sql = SqlReplaceHelper.replaceSqlByExpression(sqlQuery.getSql(), bizName2Expr); + String sql = SqlReplaceHelper.replaceSqlByExpression( + Constants.TABLE_PREFIX + queryStatement.getDataSetId(), sqlQuery.getSql(), + bizName2Expr); sqlQuery.setSql(sql); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/MetricExpressionParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/MetricExpressionParser.java index 722c31352..7628bc265 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/MetricExpressionParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/MetricExpressionParser.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.translator.parser; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.Measure; import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; @@ -39,7 +40,9 @@ public class MetricExpressionParser implements QueryParser { Map bizName2Expr = getMetricExpressions(semanticSchema, ontologyQuery); if (!CollectionUtils.isEmpty(bizName2Expr)) { - String sql = SqlReplaceHelper.replaceSqlByExpression(sqlQuery.getSql(), bizName2Expr); + String sql = SqlReplaceHelper.replaceSqlByExpression( + Constants.TABLE_PREFIX + queryStatement.getDataSetId(), sqlQuery.getSql(), + bizName2Expr); sqlQuery.setSql(sql); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java index d38bb1bef..361fffe51 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java @@ -347,6 +347,7 @@ public class SchemaServiceImpl implements SchemaService { DataSetSchemaResp dataSetSchemaResp = fetchDataSetSchema(schemaFilterReq.getDataSetId()); BeanUtils.copyProperties(dataSetSchemaResp, semanticSchemaResp); + semanticSchemaResp.setDataSetResp(dataSetSchemaResp); List modelIds = dataSetSchemaResp.getAllModels(); MetaFilter metaFilter = new MetaFilter(); metaFilter.setIds(modelIds); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java index 445a83274..a7978d8fd 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java @@ -111,4 +111,18 @@ public class TranslatorTest extends BaseTest { executeSql(explain.getQuerySQL()); } + @Test + @SetSystemProperty(key = "s2.test", value = "true") + public void testSql_subquery() throws Exception { + String sql = new String( + Files.readAllBytes( + Paths.get(ClassLoader.getSystemResource("sql/testSubquery.sql").toURI())), + StandardCharsets.UTF_8); + SemanticTranslateResp explain = semanticLayerService + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + assertNotNull(explain); + assertNotNull(explain.getQuerySQL()); + executeSql(explain.getQuerySQL()); + } + } diff --git a/launchers/standalone/src/test/resources/sql/testSubquery.sql b/launchers/standalone/src/test/resources/sql/testSubquery.sql new file mode 100644 index 000000000..53b51f5f5 --- /dev/null +++ b/launchers/standalone/src/test/resources/sql/testSubquery.sql @@ -0,0 +1,27 @@ +WITH + _average_stay_duration_ AS ( + SELECT + AVG(停留时长) AS _avg_duration_ + FROM + 超音数数据集 + ) +SELECT + 用户名, + SUM(停留时长) AS _total_stay_duration_ +FROM + 超音数数据集 +GROUP BY + 用户名 +HAVING + SUM(停留时长) > ( + SELECT + _avg_duration_ * 1.5 + FROM + _average_stay_duration_ + ) + OR SUM(停留时长) < ( + SELECT + _avg_duration_ * 0.5 + FROM + _average_stay_duration_ + ) \ No newline at end of file