[improvement][headless]Support s2sql with union all statements.

This commit is contained in:
jerryjzhang
2024-12-27 11:25:33 +08:00
parent ce9ae1c0c1
commit ade03627ce
6 changed files with 154 additions and 18 deletions

View File

@@ -229,6 +229,26 @@ public class SqlReplaceHelper {
orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace)); orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace));
} }
} }
List<Select> selects = operationList.getSelects();
if (!CollectionUtils.isEmpty(selects)) {
for (Select select : selects) {
if (select instanceof PlainSelect) {
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, (PlainSelect) select);
}
}
}
List<WithItem> withItems = operationList.getWithItemsList();
if (!CollectionUtils.isEmpty(withItems)) {
for (WithItem withItem : withItems) {
Select select = withItem.getSelect();
if (select instanceof PlainSelect) {
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, (PlainSelect) select);
} else if (select instanceof ParenthesedSelect) {
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace,
select.getPlainSelect());
}
}
}
} }
public static String replaceFunction(String sql, Map<String, String> functionMap) { public static String replaceFunction(String sql, Map<String, String> functionMap) {
@@ -610,6 +630,25 @@ public class SqlReplaceHelper {
plainSelectList.add(subPlainSelect); plainSelectList.add(subPlainSelect);
}); });
} }
List<Select> selects = setOperationList.getSelects();
if (!CollectionUtils.isEmpty(selects)) {
for (Select select : selects) {
if (select instanceof PlainSelect) {
plainSelectList.add((PlainSelect) select);
}
}
}
List<WithItem> withItems = setOperationList.getWithItemsList();
if (!CollectionUtils.isEmpty(withItems)) {
for (WithItem withItem : withItems) {
Select select = withItem.getSelect();
if (select instanceof PlainSelect) {
plainSelectList.add((PlainSelect) select);
} else if (select instanceof ParenthesedSelect) {
plainSelectList.add(select.getPlainSelect());
}
}
}
} else { } else {
return sql; return sql;
} }

View File

@@ -190,7 +190,7 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
private SemanticQueryResp queryWithColumns(String sql, DatabaseResp database) { private SemanticQueryResp queryWithColumns(String sql, DatabaseResp database) {
SemanticQueryResp queryResultWithColumns = new SemanticQueryResp(); SemanticQueryResp queryResultWithColumns = new SemanticQueryResp();
SqlUtils sqlUtils = this.sqlUtils.init(database); SqlUtils sqlUtils = this.sqlUtils.init(database);
log.info("query SQL: {}", sql); log.info("query SQL: {}", StringUtils.normalizeSpace(sql));
sqlUtils.queryInternal(sql, queryResultWithColumns); sqlUtils.queryInternal(sql, queryResultWithColumns);
return queryResultWithColumns; return queryResultWithColumns;
} }

View File

@@ -10,10 +10,12 @@ import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO; import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
@@ -22,6 +24,7 @@ import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.SchemaService; import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import java.util.ArrayList; import java.util.ArrayList;
@@ -46,6 +49,7 @@ public class BaseTest extends BaseApplication {
protected Agent agent; protected Agent agent;
protected SemanticSchema schema; protected SemanticSchema schema;
protected DatabaseResp databaseResp;
protected Agent getAgentByName(String agentName) { protected Agent getAgentByName(String agentName) {
Optional<Agent> agent = agentService.getAgents().stream() Optional<Agent> agent = agentService.getAgents().stream()
@@ -62,6 +66,16 @@ public class BaseTest extends BaseApplication {
return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user); return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user);
} }
protected void executeSql(String sql) {
if (databaseResp == null) {
databaseResp = databaseService.getDatabase(1L);
}
SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp);
assert StringUtils.isBlank(queryResp.getErrorMsg());
System.out.println(
String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList())));
}
protected SemanticQueryReq buildQuerySqlReq(String sql) { protected SemanticQueryReq buildQuerySqlReq(String sql) {
QuerySqlReq querySqlCmd = new QuerySqlReq(); QuerySqlReq querySqlCmd = new QuerySqlReq();
querySqlCmd.setSql(sql); querySqlCmd.setSql(sql);

View File

@@ -1,18 +1,17 @@
package com.tencent.supersonic.headless; package com.tencent.supersonic.headless;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.demo.S2VisitsDemo; import com.tencent.supersonic.demo.S2VisitsDemo;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty; import org.junitpioneer.jupiter.SetSystemProperty;
import java.util.Optional; import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Objects;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@@ -21,22 +20,13 @@ public class TranslatorTest extends BaseTest {
private Long dataSetId; private Long dataSetId;
private DatabaseResp databaseResp;
@BeforeEach @BeforeEach
public void init() { public void init() {
agent = getAgentByName(S2VisitsDemo.AGENT_NAME); agent = getAgentByName(S2VisitsDemo.AGENT_NAME);
schema = schemaService.getSemanticSchema(agent.getDataSetIds()); schema = schemaService.getSemanticSchema(agent.getDataSetIds());
Optional<Long> id = agent.getDataSetIds().stream().findFirst(); if (Objects.nonNull(agent)) {
dataSetId = id.orElse(1L); dataSetId = agent.getDataSetIds().stream().findFirst().get();
databaseResp = databaseService.getDatabase(1L); }
}
private void executeSql(String sql) {
SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp);
assert StringUtils.isBlank(queryResp.getErrorMsg());
System.out.println(
String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList())));
} }
@Test @Test
@@ -91,4 +81,34 @@ public class TranslatorTest extends BaseTest {
executeSql(explain.getQuerySQL()); executeSql(explain.getQuerySQL());
} }
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testSql_unionALL() throws Exception {
String sql = new String(
Files.readAllBytes(
Paths.get(ClassLoader.getSystemResource("sql/testUnion.sql").toURI())),
StandardCharsets.UTF_8);
SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department"));
assertTrue(explain.getQuerySQL().contains("pv"));
executeSql(explain.getQuerySQL());
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testSql_with() throws Exception {
String sql = new String(
Files.readAllBytes(
Paths.get(ClassLoader.getSystemResource("sql/testWith.sql").toURI())),
StandardCharsets.UTF_8);
SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());
executeSql(explain.getQuerySQL());
}
} }

View File

@@ -0,0 +1,34 @@
WITH
recent_week AS (
SELECT
SUM(访) AS _访问次数_,
COUNT(DISTINCT ) AS _访问用户数_
FROM
WHERE
>= '2024-12-20'
AND <= '2024-12-27'
),
first_week_december AS (
SELECT
SUM(访) AS _访问次数_,
COUNT(DISTINCT ) AS _访问用户数_
FROM
WHERE
>= '2024-12-01'
AND <= '2024-12-07'
)
SELECT
'最近7天' AS _时间段_,
_访问次数_,
_访问用户数_
FROM
recent_week
UNION ALL
SELECT
'12月第一个星期' AS _时间段_,
_访问次数_,
_访问用户数_
FROM
first_week_december

View File

@@ -0,0 +1,29 @@
WITH
weekly_visits AS (
SELECT
YEAR () AS _year_,
WEEK () AS _week_,
SUM(访) AS total_visits
FROM
WHERE
(
>= '2024-11-18'
AND <= '2024-11-25'
)
GROUP BY
YEAR (),
WEEK ()
)
SELECT
_year_,
_week_,
total_visits
FROM
weekly_visits
WHERE
(_year_ = YEAR (CURRENT_DATE))
ORDER BY
total_visits DESC
LIMIT
1