diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 7f619b3c9..1e5bd50c3 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -38,6 +38,10 @@ public class SqlReplaceHelper { private final static double replaceColumnThreshold = 0.4; + public static String escapeTableName(String table) { + return String.format("`%s`", table); + } + public static String replaceAggFields(String sql, Map> fieldNameToAggMap) { Select selectStatement = SqlSelectHelper.getSelect(sql); diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index 63bbf06b2..11666a86e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -228,7 +228,7 @@ public class SqlSelectHelper { statement = CCJSqlParserUtil.parse(sql); } catch (JSQLParserException e) { log.error("parse error, sql:{}", sql, e); - return null; + throw new RuntimeException(e); } if (statement instanceof ParenthesedSelect) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java index 2880111f2..55b3bbbf3 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Lists; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; +import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; @@ -281,7 +282,7 @@ public class QueryStructReq extends SemanticQueryReq { public String getTableName() { if (StringUtils.isNotBlank(dataSetName)) { - return dataSetName; + return SqlReplaceHelper.escapeTableName(dataSetName); } if (dataSetId != null) { return Constants.TABLE_PREFIX + dataSetId; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java index 22f21f700..dddcee601 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java @@ -15,12 +15,19 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; +import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; import org.springframework.util.CollectionUtils; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; @Slf4j @@ -97,15 +104,16 @@ public class QueryReqBuilder { * convert to QueryS2SQLReq * * @param querySql - * @param dataSetId + * @param dataSet * @return */ - public static QuerySqlReq buildS2SQLReq(String querySql, Long dataSetId) { + public static QuerySqlReq buildS2SQLReq(String querySql, DataSetResp dataSet) { QuerySqlReq querySQLReq = new QuerySqlReq(); if (Objects.nonNull(querySql)) { querySQLReq.setSql(querySql); } - querySQLReq.setDataSetId(dataSetId); + querySQLReq.setDataSetId(dataSet.getId()); + querySQLReq.setDataSetName(dataSet.getName()); return querySQLReq; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java index 2cb623acc..a4836182b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.auth.api.authorization.service.AuthService; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; +import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.QueryAuthorization; import com.tencent.supersonic.common.pojo.User; @@ -73,6 +74,15 @@ public class S2DataPermissionAspect { SemanticQueryReq queryReq = null; if (objects[0] instanceof SemanticQueryReq) { queryReq = (SemanticQueryReq) objects[0]; + if (queryReq instanceof QuerySqlReq) { + QuerySqlReq sqlReq = (QuerySqlReq) queryReq; + if (sqlReq.getDataSetName() != null) { + String escapedTable = SqlReplaceHelper.escapeTableName(sqlReq.getDataSetName()); + sqlReq.setSql(sqlReq.getSql().replaceAll( + String.format(" %s ", sqlReq.getDataSetName()), + String.format(" %s ", escapedTable))); + } + } } if (queryReq == null) { throw new InvalidArgumentException("queryReq is not Invalid"); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java index df61c7d43..3a29e9d47 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java @@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO; import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository; +import com.tencent.supersonic.headless.server.service.DataSetService; import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.SchemaService; import com.tencent.supersonic.util.DataUtils; @@ -46,6 +47,8 @@ public class BaseTest extends BaseApplication { private AgentService agentService; @Autowired protected DatabaseService databaseService; + @Autowired + protected DataSetService dataSetService; protected Agent agent; protected SemanticSchema schema; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java index 127e29d6a..3e47c5b94 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslatorTest.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.demo.S2VisitsDemo; +import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import org.junit.jupiter.api.BeforeEach; @@ -18,14 +19,15 @@ import static org.junit.Assert.assertTrue; public class TranslatorTest extends BaseTest { - private Long dataSetId; + private DataSetResp dataSet; @BeforeEach public void init() { agent = getAgentByName(S2VisitsDemo.AGENT_NAME); schema = schemaService.getSemanticSchema(agent.getDataSetIds()); if (Objects.nonNull(agent)) { - dataSetId = agent.getDataSetIds().stream().findFirst().get(); + long dataSetId = agent.getDataSetIds().stream().findFirst().get(); + dataSet = dataSetService.getDataSet(dataSetId); } } @@ -34,7 +36,7 @@ public class TranslatorTest extends BaseTest { String sql = "SELECT SUM(访问次数) AS _总访问次数_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15'"; SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); assertTrue(explain.getQuerySQL().contains("count(1)")); @@ -45,7 +47,7 @@ public class TranslatorTest extends BaseTest { public void testSql_1() throws Exception { String sql = "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); assertTrue(explain.getQuerySQL().contains("department")); @@ -59,7 +61,7 @@ public class TranslatorTest extends BaseTest { String sql = "WITH _department_visits_ AS (SELECT 部门, SUM(访问次数) AS _total_visits_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15' GROUP BY 部门) SELECT 部门 FROM _department_visits_ ORDER BY _total_visits_ DESC LIMIT 2"; SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); assertTrue(explain.getQuerySQL().toLowerCase().contains("department")); @@ -73,7 +75,7 @@ public class TranslatorTest extends BaseTest { String sql = "WITH recent_data AS (SELECT 用户名, 访问次数 FROM 超音数数据集 WHERE 部门 = 'marketing' AND 数据日期 >= '2024-12-01' AND 数据日期 <= '2024-12-15') SELECT 用户名 FROM recent_data ORDER BY 访问次数 DESC LIMIT 1"; SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); assertTrue(explain.getQuerySQL().toLowerCase().contains("department")); @@ -89,7 +91,7 @@ public class TranslatorTest extends BaseTest { Paths.get(ClassLoader.getSystemResource("sql/testUnion.sql").toURI())), StandardCharsets.UTF_8); SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); assertTrue(explain.getQuerySQL().contains("user_name")); @@ -105,7 +107,7 @@ public class TranslatorTest extends BaseTest { Paths.get(ClassLoader.getSystemResource("sql/testWith.sql").toURI())), StandardCharsets.UTF_8); SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); executeSql(explain.getQuerySQL()); @@ -119,7 +121,7 @@ public class TranslatorTest extends BaseTest { Paths.get(ClassLoader.getSystemResource("sql/testSubquery.sql").toURI())), StandardCharsets.UTF_8); SemanticTranslateResp explain = semanticLayerService - .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); + .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); executeSql(explain.getQuerySQL());