From ade03627ce0f89820f3dfdeb3cf26537e918a0dd Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Fri, 27 Dec 2024 11:25:33 +0800 Subject: [PATCH] [improvement][headless]Support s2sql with union all statements. --- .../common/jsqlparser/SqlReplaceHelper.java | 39 ++++++++++++++ .../service/impl/DatabaseServiceImpl.java | 2 +- .../tencent/supersonic/headless/BaseTest.java | 14 +++++ .../supersonic/headless/TranslatorTest.java | 54 +++++++++++++------ .../src/test/resources/sql/testUnion.sql | 34 ++++++++++++ .../src/test/resources/sql/testWith.sql | 29 ++++++++++ 6 files changed, 154 insertions(+), 18 deletions(-) create mode 100644 launchers/standalone/src/test/resources/sql/testUnion.sql create mode 100644 launchers/standalone/src/test/resources/sql/testWith.sql diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 1fec3a61e..0adffc8ab 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -229,6 +229,26 @@ public class SqlReplaceHelper { orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace)); } } + List selects = setOperationList.getSelects(); + if (!CollectionUtils.isEmpty(selects)) { + for (Select select : selects) { + if (select instanceof PlainSelect) { + plainSelectList.add((PlainSelect) select); + } + } + } + List withItems = setOperationList.getWithItemsList(); + if (!CollectionUtils.isEmpty(withItems)) { + for (WithItem withItem : withItems) { + Select select = withItem.getSelect(); + if (select instanceof PlainSelect) { + plainSelectList.add((PlainSelect) select); + } else if (select instanceof ParenthesedSelect) { + plainSelectList.add(select.getPlainSelect()); + } + } + } } else { return sql; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index 3ac8712d9..b0288b4eb 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -190,7 +190,7 @@ public class DatabaseServiceImpl extends ServiceImpl agent = agentService.getAgents().stream() @@ -62,6 +66,16 @@ public class BaseTest extends BaseApplication { return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user); } + protected void executeSql(String sql) { + if (databaseResp == null) { + databaseResp = databaseService.getDatabase(1L); + } + SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp); + assert StringUtils.isBlank(queryResp.getErrorMsg()); + System.out.println( + String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList()))); + } + protected SemanticQueryReq buildQuerySqlReq(String sql) { QuerySqlReq querySqlCmd = new QuerySqlReq(); querySqlCmd.setSql(sql); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java index 1814299ee..445a83274 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java @@ -1,18 +1,17 @@ package com.tencent.supersonic.headless; import com.tencent.supersonic.common.pojo.User; -import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.demo.S2VisitsDemo; -import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; -import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; -import org.apache.commons.lang3.StringUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junitpioneer.jupiter.SetSystemProperty; -import java.util.Optional; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Objects; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -21,22 +20,13 @@ public class TranslatorTest extends BaseTest { private Long dataSetId; - private DatabaseResp databaseResp; - @BeforeEach public void init() { agent = getAgentByName(S2VisitsDemo.AGENT_NAME); schema = schemaService.getSemanticSchema(agent.getDataSetIds()); - Optional id = agent.getDataSetIds().stream().findFirst(); - dataSetId = id.orElse(1L); - databaseResp = databaseService.getDatabase(1L); - } - - private void executeSql(String sql) { - SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp); - assert StringUtils.isBlank(queryResp.getErrorMsg()); - System.out.println( - String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList()))); + if (Objects.nonNull(agent)) { + dataSetId = agent.getDataSetIds().stream().findFirst().get(); + } } @Test @@ -91,4 +81,34 @@ public class TranslatorTest extends BaseTest { executeSql(explain.getQuerySQL()); } + @Test + @SetSystemProperty(key = "s2.test", value = "true") + public void testSql_unionALL() throws Exception { + String sql = new String( + Files.readAllBytes( + Paths.get(ClassLoader.getSystemResource("sql/testUnion.sql").toURI())), + StandardCharsets.UTF_8); + SemanticTranslateResp explain = semanticLayerService + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + assertNotNull(explain); + assertNotNull(explain.getQuerySQL()); + assertTrue(explain.getQuerySQL().contains("department")); + assertTrue(explain.getQuerySQL().contains("pv")); + executeSql(explain.getQuerySQL()); + } + + @Test + @SetSystemProperty(key = "s2.test", value = "true") + public void testSql_with() throws Exception { + String sql = new String( + Files.readAllBytes( + Paths.get(ClassLoader.getSystemResource("sql/testWith.sql").toURI())), + StandardCharsets.UTF_8); + SemanticTranslateResp explain = semanticLayerService + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + assertNotNull(explain); + assertNotNull(explain.getQuerySQL()); + executeSql(explain.getQuerySQL()); + } + } diff --git a/launchers/standalone/src/test/resources/sql/testUnion.sql b/launchers/standalone/src/test/resources/sql/testUnion.sql new file mode 100644 index 000000000..5bb775831 --- /dev/null +++ b/launchers/standalone/src/test/resources/sql/testUnion.sql @@ -0,0 +1,34 @@ +WITH + recent_week AS ( + SELECT + SUM(访问次数) AS _访问次数_, + COUNT(DISTINCT 用户名) AS _访问用户数_ + FROM + 超音数数据集 + WHERE + 数据日期 >= '2024-12-20' + AND 数据日期 <= '2024-12-27' + ), + first_week_december AS ( + SELECT + SUM(访问次数) AS _访问次数_, + COUNT(DISTINCT 用户名) AS _访问用户数_ + FROM + 超音数数据集 + WHERE + 数据日期 >= '2024-12-01' + AND 数据日期 <= '2024-12-07' + ) +SELECT + '最近7天' AS _时间段_, + _访问次数_, + _访问用户数_ +FROM + recent_week +UNION ALL +SELECT + '12月第一个星期' AS _时间段_, + _访问次数_, + _访问用户数_ +FROM + first_week_december \ No newline at end of file diff --git a/launchers/standalone/src/test/resources/sql/testWith.sql b/launchers/standalone/src/test/resources/sql/testWith.sql new file mode 100644 index 000000000..7ef9596f4 --- /dev/null +++ b/launchers/standalone/src/test/resources/sql/testWith.sql @@ -0,0 +1,29 @@ +WITH + weekly_visits AS ( + SELECT + YEAR (数据日期) AS _year_, + WEEK (数据日期) AS _week_, + SUM(访问次数) AS total_visits + FROM + 超音数数据集 + WHERE + ( + 数据日期 >= '2024-11-18' + AND 数据日期 <= '2024-11-25' + ) + GROUP BY + YEAR (数据日期), + WEEK (数据日期) + ) +SELECT + _year_, + _week_, + total_visits +FROM + weekly_visits +WHERE + (_year_ = YEAR (CURRENT_DATE)) +ORDER BY + total_visits DESC +LIMIT + 1 \ No newline at end of file