From b122053e98b06911985b2425cacc91d1ded51d37 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 9 May 2024 10:47:58 +0800 Subject: [PATCH] (improvement)(Headless) Remove backticks when querying through SQL. (#968) --- .../supersonic/common/util/StringUtil.java | 6 ++++-- .../common/util/jsqlparser/SqlSelectHelper.java | 16 +++++++++------- .../util/jsqlparser/SqlSelectHelperTest.java | 12 +++++++++++- .../server/rest/api/SqlQueryApiController.java | 6 +++++- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java index f3a28c647..c7b3f3f08 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java @@ -26,10 +26,9 @@ public class StringUtil { } /** - * * @param v1 * @param v2 - * @return value 0 if v1 equal to v2; less than 0 if v1 is less than v2; greater than 0 if v1 is greater than v2 + * @return value 0 if v1 equal to v2; less than 0 if v1 is less than v2; greater than 0 if v1 is greater than v2 */ public static int compareVersion(String v1, String v2) { String[] v1s = v1.split("\\."); @@ -46,4 +45,7 @@ public class StringUtil { return v1s.length - v2s.length; } + public static String replaceBackticks(String sql) { + return sql.replaceAll("`", ""); + } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java index 66099c79c..b5aec7c5a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java @@ -1,11 +1,6 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; +import com.tencent.supersonic.common.util.StringUtil; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Alias; @@ -44,6 +39,13 @@ import net.sf.jsqlparser.statement.select.SetOperationList; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + /** * Sql Parser Select Helper */ @@ -363,7 +365,7 @@ public class SqlSelectHelper { public static String getTableName(String sql) { Table table = getTable(sql); - return table.getName(); + return StringUtil.replaceBackticks(table.getName()); } public static List getAggregateFields(String sql) { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelperTest.java index 6693850e5..a7728f253 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelperTest.java @@ -1,11 +1,12 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.List; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.statement.select.Select; import org.junit.Assert; import org.junit.jupiter.api.Test; +import java.util.List; + /** * SqlParserSelectHelper Test */ @@ -272,4 +273,13 @@ class SqlSelectHelperTest { } + @Test + void testGetTableName() { + + String sql = "select 部门,sum (访问次数) from `超音数` where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; + String tableName = SqlSelectHelper.getTableName(sql); + Assert.assertEquals(tableName, "超音数"); + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java index de6965f57..55a11c3a7 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.rest.api; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; @@ -36,6 +37,8 @@ public class SqlQueryApiController { HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); + String sql = querySqlReq.getSql(); + querySqlReq.setSql(StringUtil.replaceBackticks(sql)); chatQueryService.correct(querySqlReq, user); return queryService.queryByReq(querySqlReq, user); } @@ -49,10 +52,11 @@ public class SqlQueryApiController { .stream().map(sql -> { QuerySqlReq querySqlReq = new QuerySqlReq(); BeanUtils.copyProperties(querySqlsReq, querySqlReq); - querySqlReq.setSql(sql); + querySqlReq.setSql(StringUtil.replaceBackticks(sql)); chatQueryService.correct(querySqlReq, user); return querySqlReq; }).collect(Collectors.toList()); return queryService.queryByReqs(semanticQueryReqs, user); } + }