[improvement][headless]Support dataSetNames that contain dash.
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run

[improvement][headless]Support dataSetNames that contain dash.

[improvement][headless]Support dataSetNames that contain dash.
This commit is contained in:
jerryjzhang
2025-02-21 01:11:58 +08:00
parent 5fa3607874
commit 94e853f57e
7 changed files with 43 additions and 15 deletions

View File

@@ -38,6 +38,10 @@ public class SqlReplaceHelper {
private final static double replaceColumnThreshold = 0.4; private final static double replaceColumnThreshold = 0.4;
public static String escapeTableName(String table) {
return String.format("`%s`", table);
}
public static String replaceAggFields(String sql, public static String replaceAggFields(String sql,
Map<String, Pair<String, String>> fieldNameToAggMap) { Map<String, Pair<String, String>> fieldNameToAggMap) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);

View File

@@ -228,7 +228,7 @@ public class SqlSelectHelper {
statement = CCJSqlParserUtil.parse(sql); statement = CCJSqlParserUtil.parse(sql);
} catch (JSQLParserException e) { } catch (JSQLParserException e) {
log.error("parse error, sql:{}", sql, e); log.error("parse error, sql:{}", sql, e);
return null; throw new RuntimeException(e);
} }
if (statement instanceof ParenthesedSelect) { if (statement instanceof ParenthesedSelect) {

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
@@ -281,7 +282,7 @@ public class QueryStructReq extends SemanticQueryReq {
public String getTableName() { public String getTableName() {
if (StringUtils.isNotBlank(dataSetName)) { if (StringUtils.isNotBlank(dataSetName)) {
return dataSetName; return SqlReplaceHelper.escapeTableName(dataSetName);
} }
if (dataSetId != null) { if (dataSetId != null) {
return Constants.TABLE_PREFIX + dataSetId; return Constants.TABLE_PREFIX + dataSetId;

View File

@@ -15,12 +15,19 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.*; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -97,15 +104,16 @@ public class QueryReqBuilder {
* convert to QueryS2SQLReq * convert to QueryS2SQLReq
* *
* @param querySql * @param querySql
* @param dataSetId * @param dataSet
* @return * @return
*/ */
public static QuerySqlReq buildS2SQLReq(String querySql, Long dataSetId) { public static QuerySqlReq buildS2SQLReq(String querySql, DataSetResp dataSet) {
QuerySqlReq querySQLReq = new QuerySqlReq(); QuerySqlReq querySQLReq = new QuerySqlReq();
if (Objects.nonNull(querySql)) { if (Objects.nonNull(querySql)) {
querySQLReq.setSql(querySql); querySQLReq.setSql(querySql);
} }
querySQLReq.setDataSetId(dataSetId); querySQLReq.setDataSetId(dataSet.getId());
querySQLReq.setDataSetName(dataSet.getName());
return querySQLReq; return querySQLReq;
} }

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
import com.tencent.supersonic.auth.api.authorization.service.AuthService; import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.QueryAuthorization; import com.tencent.supersonic.common.pojo.QueryAuthorization;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
@@ -73,6 +74,15 @@ public class S2DataPermissionAspect {
SemanticQueryReq queryReq = null; SemanticQueryReq queryReq = null;
if (objects[0] instanceof SemanticQueryReq) { if (objects[0] instanceof SemanticQueryReq) {
queryReq = (SemanticQueryReq) objects[0]; queryReq = (SemanticQueryReq) objects[0];
if (queryReq instanceof QuerySqlReq) {
QuerySqlReq sqlReq = (QuerySqlReq) queryReq;
if (sqlReq.getDataSetName() != null) {
String escapedTable = SqlReplaceHelper.escapeTableName(sqlReq.getDataSetName());
sqlReq.setSql(sqlReq.getSql().replaceAll(
String.format(" %s ", sqlReq.getDataSetName()),
String.format(" %s ", escapedTable)));
}
}
} }
if (queryReq == null) { if (queryReq == null) {
throw new InvalidArgumentException("queryReq is not Invalid"); throw new InvalidArgumentException("queryReq is not Invalid");

View File

@@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO; import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository; import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.SchemaService; import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
@@ -46,6 +47,8 @@ public class BaseTest extends BaseApplication {
private AgentService agentService; private AgentService agentService;
@Autowired @Autowired
protected DatabaseService databaseService; protected DatabaseService databaseService;
@Autowired
protected DataSetService dataSetService;
protected Agent agent; protected Agent agent;
protected SemanticSchema schema; protected SemanticSchema schema;

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.demo.S2VisitsDemo; import com.tencent.supersonic.demo.S2VisitsDemo;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@@ -18,14 +19,15 @@ import static org.junit.Assert.assertTrue;
public class TranslatorTest extends BaseTest { public class TranslatorTest extends BaseTest {
private Long dataSetId; private DataSetResp dataSet;
@BeforeEach @BeforeEach
public void init() { public void init() {
agent = getAgentByName(S2VisitsDemo.AGENT_NAME); agent = getAgentByName(S2VisitsDemo.AGENT_NAME);
schema = schemaService.getSemanticSchema(agent.getDataSetIds()); schema = schemaService.getSemanticSchema(agent.getDataSetIds());
if (Objects.nonNull(agent)) { if (Objects.nonNull(agent)) {
dataSetId = agent.getDataSetIds().stream().findFirst().get(); long dataSetId = agent.getDataSetIds().stream().findFirst().get();
dataSet = dataSetService.getDataSet(dataSetId);
} }
} }
@@ -34,7 +36,7 @@ public class TranslatorTest extends BaseTest {
String sql = String sql =
"SELECT SUM(访问次数) AS _总访问次数_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15'"; "SELECT SUM(访问次数) AS _总访问次数_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15'";
SemanticTranslateResp explain = semanticLayerService SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("count(1)")); assertTrue(explain.getQuerySQL().contains("count(1)"));
@@ -45,7 +47,7 @@ public class TranslatorTest extends BaseTest {
public void testSql_1() throws Exception { public void testSql_1() throws Exception {
String sql = "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; String sql = "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
SemanticTranslateResp explain = semanticLayerService SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department")); assertTrue(explain.getQuerySQL().contains("department"));
@@ -59,7 +61,7 @@ public class TranslatorTest extends BaseTest {
String sql = String sql =
"WITH _department_visits_ AS (SELECT 部门, SUM(访问次数) AS _total_visits_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15' GROUP BY 部门) SELECT 部门 FROM _department_visits_ ORDER BY _total_visits_ DESC LIMIT 2"; "WITH _department_visits_ AS (SELECT 部门, SUM(访问次数) AS _total_visits_ FROM 超音数数据集 WHERE 数据日期 >= '2024-11-15' AND 数据日期 <= '2024-12-15' GROUP BY 部门) SELECT 部门 FROM _department_visits_ ORDER BY _total_visits_ DESC LIMIT 2";
SemanticTranslateResp explain = semanticLayerService SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().toLowerCase().contains("department")); assertTrue(explain.getQuerySQL().toLowerCase().contains("department"));
@@ -73,7 +75,7 @@ public class TranslatorTest extends BaseTest {
String sql = String sql =
"WITH recent_data AS (SELECT 用户名, 访问次数 FROM 超音数数据集 WHERE 部门 = 'marketing' AND 数据日期 >= '2024-12-01' AND 数据日期 <= '2024-12-15') SELECT 用户名 FROM recent_data ORDER BY 访问次数 DESC LIMIT 1"; "WITH recent_data AS (SELECT 用户名, 访问次数 FROM 超音数数据集 WHERE 部门 = 'marketing' AND 数据日期 >= '2024-12-01' AND 数据日期 <= '2024-12-15') SELECT 用户名 FROM recent_data ORDER BY 访问次数 DESC LIMIT 1";
SemanticTranslateResp explain = semanticLayerService SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().toLowerCase().contains("department")); assertTrue(explain.getQuerySQL().toLowerCase().contains("department"));
@@ -89,7 +91,7 @@ public class TranslatorTest extends BaseTest {
Paths.get(ClassLoader.getSystemResource("sql/testUnion.sql").toURI())), Paths.get(ClassLoader.getSystemResource("sql/testUnion.sql").toURI())),
StandardCharsets.UTF_8); StandardCharsets.UTF_8);
SemanticTranslateResp explain = semanticLayerService SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("user_name")); assertTrue(explain.getQuerySQL().contains("user_name"));
@@ -105,7 +107,7 @@ public class TranslatorTest extends BaseTest {
Paths.get(ClassLoader.getSystemResource("sql/testWith.sql").toURI())), Paths.get(ClassLoader.getSystemResource("sql/testWith.sql").toURI())),
StandardCharsets.UTF_8); StandardCharsets.UTF_8);
SemanticTranslateResp explain = semanticLayerService SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
executeSql(explain.getQuerySQL()); executeSql(explain.getQuerySQL());
@@ -119,7 +121,7 @@ public class TranslatorTest extends BaseTest {
Paths.get(ClassLoader.getSystemResource("sql/testSubquery.sql").toURI())), Paths.get(ClassLoader.getSystemResource("sql/testSubquery.sql").toURI())),
StandardCharsets.UTF_8); StandardCharsets.UTF_8);
SemanticTranslateResp explain = semanticLayerService SemanticTranslateResp explain = semanticLayerService
.translate(QueryReqBuilder.buildS2SQLReq(sql, dataSetId), User.getDefaultUser()); .translate(QueryReqBuilder.buildS2SQLReq(sql, dataSet), User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
executeSql(explain.getQuerySQL()); executeSql(explain.getQuerySQL());