(improvement)(Headless) Remove backticks when querying through SQL. (#968)

This commit is contained in:
lexluo09
2024-05-09 10:47:58 +08:00
committed by GitHub
parent e11aeafbc0
commit b122053e98
4 changed files with 29 additions and 11 deletions

View File

@@ -26,10 +26,9 @@ public class StringUtil {
}
/**
*
* @param v1
* @param v2
* @return value 0 if v1 equal to v2; less than 0 if v1 is less than v2; greater than 0 if v1 is greater than v2
* @return value 0 if v1 equal to v2; less than 0 if v1 is less than v2; greater than 0 if v1 is greater than v2
*/
public static int compareVersion(String v1, String v2) {
String[] v1s = v1.split("\\.");
@@ -46,4 +45,7 @@ public class StringUtil {
return v1s.length - v2s.length;
}
public static String replaceBackticks(String sql) {
return sql.replaceAll("`", "");
}
}

View File

@@ -1,11 +1,6 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.StringUtil;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
@@ -44,6 +39,13 @@ import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Sql Parser Select Helper
*/
@@ -363,7 +365,7 @@ public class SqlSelectHelper {
public static String getTableName(String sql) {
Table table = getTable(sql);
return table.getName();
return StringUtil.replaceBackticks(table.getName());
}
public static List<String> getAggregateFields(String sql) {

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.List;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.statement.select.Select;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import java.util.List;
/**
* SqlParserSelectHelper Test
*/
@@ -272,4 +273,13 @@ class SqlSelectHelperTest {
}
@Test
void testGetTableName() {
String sql = "select 部门,sum (访问次数) from `超音数` where 数据日期 = '2023-08-08'"
+ " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1";
String tableName = SqlSelectHelper.getTableName(sql);
Assert.assertEquals(tableName, "超音数");
}
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.rest.api;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
@@ -36,6 +37,8 @@ public class SqlQueryApiController {
HttpServletRequest request,
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
return queryService.queryByReq(querySqlReq, user);
}
@@ -49,10 +52,11 @@ public class SqlQueryApiController {
.stream().map(sql -> {
QuerySqlReq querySqlReq = new QuerySqlReq();
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
querySqlReq.setSql(sql);
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
return querySqlReq;
}).collect(Collectors.toList());
return queryService.queryByReqs(semanticQueryReqs, user);
}
}