From c36082476f26c214b9108360197fb15c4bf3adb0 Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Mon, 27 Nov 2023 17:02:04 +0800 Subject: [PATCH] (improvement)(chat) improve corrector and support sql union (#427) --- .../chat/corrector/HavingCorrector.java | 24 +- .../chat/corrector/WhereCorrector.java | 2 - .../chat/service/impl/QueryServiceImpl.java | 1 - .../util/jsqlparser/SqlParserAddHelper.java | 104 ++++-- .../jsqlparser/SqlParserRemoveHelper.java | 6 +- .../jsqlparser/SqlParserReplaceHelper.java | 81 ++++- .../jsqlparser/SqlParserSelectHelper.java | 299 ++++++++++-------- .../jsqlparser/SqlParserAddHelperTest.java | 28 +- .../jsqlparser/SqlParserRemoveHelperTest.java | 10 +- .../SqlParserReplaceHelperTest.java | 20 ++ .../jsqlparser/SqlParserSelectHelperTest.java | 4 +- 11 files changed, 366 insertions(+), 213 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java index a8530fc38..018bccc41 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -3,21 +3,20 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; -import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; + +import java.util.Set; +import java.util.List; +import java.util.stream.Collectors; + import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import org.springframework.util.CollectionUtils; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - /** * Perform SQL corrections on the "Having" section in S2SQL. */ @@ -33,8 +32,6 @@ public class HavingCorrector extends BaseSemanticCorrector { //add having expression filed to select addHavingToSelect(semanticParseInfo); - //remove number condition - removeNumberCondition(semanticParseInfo); } private void addHaving(SemanticParseInfo semanticParseInfo) { @@ -57,18 +54,13 @@ public class HavingCorrector extends BaseSemanticCorrector { if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) { return; } - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(correctS2SQL); - if (Objects.nonNull(havingExpression)) { - String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpression); + List havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL); + if (!CollectionUtils.isEmpty(havingExpressionList)) { + String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList); semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql); } return; } - private void removeNumberCondition(SemanticParseInfo semanticParseInfo) { - SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String correctorSql = SqlParserRemoveHelper.removeNumberCondition(sqlInfo.getCorrectS2SQL()); - sqlInfo.setCorrectS2SQL(correctorSql); - } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index 839cbc03b..6c5388685 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -12,7 +12,6 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; @@ -70,7 +69,6 @@ public class WhereCorrector extends BaseSemanticCorrector { private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL); - correctS2SQL = SqlParserRemoveHelper.removeNumberCondition(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index ff9415281..ec8a52e1a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -342,7 +342,6 @@ public class QueryServiceImpl implements QueryService { correctorSql = SqlParserAddHelper.addWhere(correctorSql, addWhereConditions); correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions); log.info("correctorSql after replacing:{}", correctorSql); - correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql); return correctorSql; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelper.java index 9ab04a2fe..b2402b8d1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelper.java @@ -4,6 +4,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.ArrayList; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; @@ -22,7 +23,7 @@ import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; -import net.sf.jsqlparser.util.SelectUtils; +import net.sf.jsqlparser.statement.select.SetOperationList; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -35,37 +36,83 @@ public class SqlParserAddHelper { public static String addFieldsToSelect(String sql, List fields) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); // add fields to select - for (String field : fields) { - SelectUtils.addExpression(selectStatement, new Column(field)); + if (selectStatement == null) { + return null; } + SelectBody selectBody = selectStatement.getSelectBody(); + if (selectBody instanceof PlainSelect) { + PlainSelect plainSelect = (PlainSelect) selectBody; + fields.stream().forEach(field -> { + SelectExpressionItem selectExpressionItem = new SelectExpressionItem(new Column(field)); + plainSelect.addSelectItems(selectExpressionItem); + }); + + } else if (selectBody instanceof SetOperationList) { + SetOperationList setOperationList = (SetOperationList) selectBody; + if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + fields.stream().forEach(field -> { + SelectExpressionItem selectExpressionItem = new SelectExpressionItem(new Column(field)); + subPlainSelect.addSelectItems(selectExpressionItem); + }); + }); + } + } + //for (String field : fields) { + // SelectUtils.addExpression(selectStatement, new Column(field)); + //} return selectStatement.toString(); } - public static String addFunctionToSelect(String sql, Expression expression) { - PlainSelect plainSelect = SqlParserSelectHelper.getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { + public static String addFunctionToSelect(String sql, List expressionList) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + if (selectStatement == null) { + return null; + } + SelectBody selectBody = selectStatement.getSelectBody(); + + List plainSelectList = new ArrayList<>(); + if (selectBody instanceof PlainSelect) { + PlainSelect plainSelect = (PlainSelect) selectBody; + plainSelectList.add(plainSelect); + } else if (selectBody instanceof SetOperationList) { + SetOperationList setOperationList = (SetOperationList) selectBody; + if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + plainSelectList.add(subPlainSelect); + }); + } + } + + if (CollectionUtils.isEmpty(plainSelectList)) { return sql; } - List selectItems = plainSelect.getSelectItems(); - if (CollectionUtils.isEmpty(selectItems)) { - return sql; - } - boolean existFunction = false; - for (SelectItem selectItem : selectItems) { - SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; - if (expressionItem.getExpression() instanceof Function) { - Function expressionFunction = (Function) expressionItem.getExpression(); - if (expression.toString().equalsIgnoreCase(expressionFunction.toString())) { - existFunction = true; - break; + for (PlainSelect plainSelect : plainSelectList) { + List selectItems = plainSelect.getSelectItems(); + if (CollectionUtils.isEmpty(selectItems)) { + continue; + } + boolean existFunction = false; + for (Expression expression : expressionList) { + for (SelectItem selectItem : selectItems) { + SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; + if (expressionItem.getExpression() instanceof Function) { + Function expressionFunction = (Function) expressionItem.getExpression(); + if (expression.toString().equalsIgnoreCase(expressionFunction.toString())) { + existFunction = true; + break; + } + } + } + if (!existFunction) { + SelectExpressionItem sumExpressionItem = new SelectExpressionItem(expression); + selectItems.add(sumExpressionItem); } } } - if (!existFunction) { - SelectExpressionItem sumExpressionItem = new SelectExpressionItem(expression); - selectItems.add(sumExpressionItem); - } - return plainSelect.toString(); + return selectStatement.toString(); } public static String addWhere(String sql, String column, Object value) { @@ -182,7 +229,7 @@ public class SqlParserAddHelper { } private static void addAggregateToSelectItems(List selectItems, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { for (SelectItem selectItem : selectItems) { if (selectItem instanceof SelectExpressionItem) { SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; @@ -197,7 +244,7 @@ public class SqlParserAddHelper { } private static void addAggregateToOrderByItems(List orderByElements, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { if (orderByElements == null) { return; } @@ -212,7 +259,7 @@ public class SqlParserAddHelper { } private static void addAggregateToGroupByItems(GroupByElement groupByElement, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { if (groupByElement == null) { return; } @@ -233,7 +280,7 @@ public class SqlParserAddHelper { } private static void modifyWhereExpression(Expression whereExpression, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { if (SqlParserSelectHelper.isLogicExpression(whereExpression)) { AndExpression andExpression = (AndExpression) whereExpression; Expression leftExpression = andExpression.getLeftExpression(); @@ -296,7 +343,8 @@ public class SqlParserAddHelper { } } } - return selectStatement.toString(); + sql = SqlParserRemoveHelper.removeNumberCondition(selectStatement.toString()); + return sql; } public static String addHaving(String sql, List expressionList) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java index 2e6736c42..989987206 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java @@ -73,7 +73,8 @@ public class SqlParserRemoveHelper { removeWhereCondition(plainSelect.getWhere(), removeFieldNames); } }); - return selectStatement.toString(); + sql = removeNumberCondition(selectStatement.toString()); + return sql; } private static void removeWhereCondition(Expression whereExpression, Set removeFieldNames) { @@ -201,7 +202,8 @@ public class SqlParserRemoveHelper { removeWhereCondition(plainSelect.getHaving(), removeFieldNames); } }); - return selectStatement.toString(); + sql = removeNumberCondition(selectStatement.toString()); + return sql; } public static String removeWhere(String sql, List fields) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java index ed9c05f17..285183c92 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java @@ -30,6 +30,7 @@ import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SubSelect; +import net.sf.jsqlparser.statement.select.SetOperationList; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -82,13 +83,15 @@ public class SqlParserReplaceHelper { } public static String replaceValue(String sql, Map> filedNameToValueMap, - boolean exactReplace) { + boolean exactReplace) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); if (!(selectBody instanceof PlainSelect)) { return sql; } - List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + List plainSelectList = new ArrayList<>(); + plainSelectList.add((PlainSelect) selectBody); + List plainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect plainSelect : plainSelects) { Expression where = plainSelect.getWhere(); FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap); @@ -105,7 +108,9 @@ public class SqlParserReplaceHelper { if (!(selectBody instanceof PlainSelect)) { return sql; } - List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + List plainSelectList = new ArrayList<>(); + plainSelectList.add((PlainSelect) selectBody); + List plainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect plainSelect : plainSelects) { Expression where = plainSelect.getWhere(); FiledNameReplaceVisitor visitor = new FiledNameReplaceVisitor(fieldValueToFieldNames); @@ -122,11 +127,31 @@ public class SqlParserReplaceHelper { public static String replaceFields(String sql, Map fieldNameMap, boolean exactReplace) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); + System.out.println(selectStatement.getSelectBody()); SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { + List plainSelectList = new ArrayList<>(); + if (selectBody instanceof PlainSelect) { + plainSelectList.add((PlainSelect) selectBody); + } else if (selectBody instanceof SetOperationList) { + SetOperationList setOperationList = (SetOperationList) selectBody; + //replace select + if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + plainSelectList.add(subPlainSelect); + }); + } + //replace order by + List orderByElements = setOperationList.getOrderByElements(); + if (!CollectionUtils.isEmpty(orderByElements)) { + for (OrderByElement orderByElement : orderByElements) { + orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace)); + } + } + } else { return sql; } - List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + List plainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect plainSelect : plainSelects) { replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, plainSelect); } @@ -134,7 +159,7 @@ public class SqlParserReplaceHelper { } private static void replaceFieldsInPlainOneSelect(Map fieldNameMap, boolean exactReplace, - PlainSelect plainSelect) { + PlainSelect plainSelect) { //1. replace where fields Expression where = plainSelect.getWhere(); FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace); @@ -170,7 +195,9 @@ public class SqlParserReplaceHelper { for (Join join : joins) { join.getOnExpression().accept(visitor); SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody(); - List subPlainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) subSelectBody); + List plainSelectList = new ArrayList<>(); + plainSelectList.add((PlainSelect) subSelectBody); + List subPlainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect subPlainSelect : subPlainSelects) { replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect); } @@ -199,7 +226,9 @@ public class SqlParserReplaceHelper { if (!(selectBody instanceof PlainSelect)) { return sql; } - List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + List plainSelectList = new ArrayList<>(); + plainSelectList.add((PlainSelect) selectBody); + List plainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect plainSelect : plainSelects) { replaceFunction(functionMap, plainSelect); } @@ -238,7 +267,9 @@ public class SqlParserReplaceHelper { if (!(selectBody instanceof PlainSelect)) { return sql; } - List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + List plainSelectList = new ArrayList<>(); + plainSelectList.add((PlainSelect) selectBody); + List plainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect plainSelect : plainSelects) { replaceFunction(plainSelect); } @@ -295,7 +326,7 @@ public class SqlParserReplaceHelper { } private static void replaceOrderByFunction(Map functionMap, - List orderByElementList) { + List orderByElementList) { if (Objects.isNull(orderByElementList)) { return; } @@ -321,7 +352,7 @@ public class SqlParserReplaceHelper { } private static void addWaitingExpression(PlainSelect plainSelect, Expression where, - List waitingForAdds) { + List waitingForAdds) { if (CollectionUtils.isEmpty(waitingForAdds)) { return; } @@ -341,9 +372,27 @@ public class SqlParserReplaceHelper { } Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); - PlainSelect plainSelect = (PlainSelect) selectBody; + if (selectBody instanceof PlainSelect) { + PlainSelect plainSelect = (PlainSelect) selectBody; + replaceSingleTable(plainSelect, tableName); + } else if (selectBody instanceof SetOperationList) { + SetOperationList setOperationList = (SetOperationList) selectBody; + if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + replaceSingleTable(subPlainSelect, tableName); + }); + } + } + + return selectStatement.toString(); + } + + public static void replaceSingleTable(PlainSelect plainSelect, String tableName) { // replace table name - List painSelects = SqlParserSelectHelper.getPlainSelects(plainSelect); + List plainSelects = new ArrayList<>(); + plainSelects.add(plainSelect); + List painSelects = SqlParserSelectHelper.getPlainSelects(plainSelects); for (PlainSelect painSelect : painSelects) { painSelect.accept( new SelectVisitorAdapter() { @@ -356,15 +405,15 @@ public class SqlParserReplaceHelper { if (!CollectionUtils.isEmpty(joins)) { for (Join join : joins) { SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody(); - List subPlainSelects = SqlParserSelectHelper.getPlainSelects( - (PlainSelect) subSelectBody); + List plainSelectList = new ArrayList<>(); + plainSelectList.add((PlainSelect) subSelectBody); + List subPlainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect subPlainSelect : subPlainSelects) { subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName)); } } } } - return selectStatement.toString(); } public static String replaceAlias(String sql) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 28c229c71..b1781a93f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -1,5 +1,12 @@ package com.tencent.supersonic.common.util.jsqlparser; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; @@ -26,16 +33,10 @@ import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SubSelect; +import net.sf.jsqlparser.statement.select.SetOperationList; import org.apache.commons.lang3.tuple.ImmutablePair; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - /** * Sql Parser Select Helper */ @@ -43,68 +44,84 @@ import java.util.stream.Collectors; public class SqlParserSelectHelper { public static List getFilterExpression(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); - } + List plainSelectList = getPlainSelect(sql); Set result = new HashSet<>(); - Expression where = plainSelect.getWhere(); - if (Objects.nonNull(where)) { - where.accept(new FieldAndValueAcquireVisitor(result)); - } - Expression having = plainSelect.getHaving(); - if (Objects.nonNull(having)) { - having.accept(new FieldAndValueAcquireVisitor(result)); + for (PlainSelect plainSelect : plainSelectList) { + if (Objects.isNull(plainSelect)) { + continue; + } + Expression where = plainSelect.getWhere(); + if (Objects.nonNull(where)) { + where.accept(new FieldAndValueAcquireVisitor(result)); + } + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + having.accept(new FieldAndValueAcquireVisitor(result)); + } } return new ArrayList<>(result); } public static List getWhereFields(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { + List plainSelectList = getPlainSelect(sql); + if (CollectionUtils.isEmpty(plainSelectList)) { return new ArrayList<>(); } Set result = new HashSet<>(); - getWhereFields(plainSelect, result); + getWhereFields(plainSelectList, result); return new ArrayList<>(result); } - private static void getWhereFields(PlainSelect plainSelect, Set result) { - Expression where = plainSelect.getWhere(); - if (Objects.nonNull(where)) { - where.accept(new FieldAcquireVisitor(result)); - } + private static void getWhereFields(List plainSelectList, Set result) { + plainSelectList.stream().forEach(plainSelect -> { + Expression where = plainSelect.getWhere(); + if (Objects.nonNull(where)) { + where.accept(new FieldAcquireVisitor(result)); + } + }); } public static List getSelectFields(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { + List plainSelectList = getPlainSelect(sql); + if (CollectionUtils.isEmpty(plainSelectList)) { return new ArrayList<>(); } - return new ArrayList<>(getSelectFields(plainSelect)); + return new ArrayList<>(getSelectFields(plainSelectList)); } - public static Set getSelectFields(PlainSelect plainSelect) { - List selectItems = plainSelect.getSelectItems(); + public static Set getSelectFields(List plainSelectList) { Set result = new HashSet<>(); - for (SelectItem selectItem : selectItems) { - selectItem.accept(new FieldAcquireVisitor(result)); - } + plainSelectList.stream().forEach(plainSelect -> { + List selectItems = plainSelect.getSelectItems(); + for (SelectItem selectItem : selectItems) { + selectItem.accept(new FieldAcquireVisitor(result)); + } + }); return result; } - public static PlainSelect getPlainSelect(String sql) { + public static List getPlainSelect(String sql) { Select selectStatement = getSelect(sql); if (selectStatement == null) { return null; } SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return null; + List plainSelectList = new ArrayList<>(); + if (selectBody instanceof PlainSelect) { + PlainSelect plainSelect = (PlainSelect) selectBody; + plainSelectList.add(plainSelect); + } else if (selectBody instanceof SetOperationList) { + SetOperationList setOperationList = (SetOperationList) selectBody; + if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + plainSelectList.add(subPlainSelect); + }); + } } - return (PlainSelect) selectBody; + return plainSelectList; } public static Select getSelect(String sql) { @@ -122,39 +139,40 @@ public class SqlParserSelectHelper { return (Select) statement; } - public static List getPlainSelects(PlainSelect plainSelect) { + public static List getPlainSelects(List plainSelectList) { List plainSelects = new ArrayList<>(); - plainSelects.add(plainSelect); - - ExpressionVisitorAdapter expressionVisitor = new ExpressionVisitorAdapter() { - @Override - public void visit(SubSelect subSelect) { - SelectBody subSelectBody = subSelect.getSelectBody(); - if (subSelectBody instanceof PlainSelect) { - plainSelects.add((PlainSelect) subSelectBody); - } - } - }; - - plainSelect.accept(new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - Expression whereExpression = plainSelect.getWhere(); - if (whereExpression != null) { - whereExpression.accept(expressionVisitor); - } - Expression having = plainSelect.getHaving(); - if (Objects.nonNull(having)) { - having.accept(expressionVisitor); - } - List selectItems = plainSelect.getSelectItems(); - if (!CollectionUtils.isEmpty(selectItems)) { - for (SelectItem selectItem : selectItems) { - selectItem.accept(expressionVisitor); + for (PlainSelect plainSelect : plainSelectList) { + plainSelects.add(plainSelect); + ExpressionVisitorAdapter expressionVisitor = new ExpressionVisitorAdapter() { + @Override + public void visit(SubSelect subSelect) { + SelectBody subSelectBody = subSelect.getSelectBody(); + if (subSelectBody instanceof PlainSelect) { + plainSelects.add((PlainSelect) subSelectBody); } } - } - }); + }; + + plainSelect.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + Expression whereExpression = plainSelect.getWhere(); + if (whereExpression != null) { + whereExpression.accept(expressionVisitor); + } + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + having.accept(expressionVisitor); + } + List selectItems = plainSelect.getSelectItems(); + if (!CollectionUtils.isEmpty(selectItems)) { + for (SelectItem selectItem : selectItems) { + selectItem.accept(expressionVisitor); + } + } + } + }); + } return plainSelects; } @@ -172,13 +190,15 @@ public class SqlParserSelectHelper { if (Objects.isNull(plainSelect)) { return new ArrayList<>(); } - Set result = getSelectFields(plainSelect); + List plainSelectList = new ArrayList<>(); + plainSelectList.add(plainSelect); + Set result = getSelectFields(plainSelectList); getGroupByFields(plainSelect, result); getOrderByFields(plainSelect, result); - getWhereFields(plainSelect, result); + getWhereFields(plainSelectList, result); getHavingFields(plainSelect, result); @@ -193,56 +213,65 @@ public class SqlParserSelectHelper { } - public static Expression getHavingExpression(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - Expression having = plainSelect.getHaving(); - if (Objects.nonNull(having)) { - if (!(having instanceof ComparisonOperator)) { - return null; - } - ComparisonOperator comparisonOperator = (ComparisonOperator) having; - if (comparisonOperator.getLeftExpression() instanceof Function) { - return comparisonOperator.getLeftExpression(); - } else if (comparisonOperator.getRightExpression() instanceof Function) { - return comparisonOperator.getRightExpression(); + public static List getHavingExpression(String sql) { + List plainSelectList = getPlainSelect(sql); + List expressionList = new ArrayList<>(); + for (PlainSelect plainSelect : plainSelectList) { + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + if (!(having instanceof ComparisonOperator)) { + continue; + } + ComparisonOperator comparisonOperator = (ComparisonOperator) having; + if (comparisonOperator.getLeftExpression() instanceof Function) { + expressionList.add(comparisonOperator.getLeftExpression()); + } else if (comparisonOperator.getRightExpression() instanceof Function) { + expressionList.add(comparisonOperator.getRightExpression()); + } } } - return null; + return expressionList; } public static List getWhereExpressions(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); - } + List plainSelectList = getPlainSelect(sql); Set result = new HashSet<>(); - Expression where = plainSelect.getWhere(); - if (Objects.nonNull(where)) { - where.accept(new FieldAndValueAcquireVisitor(result)); + for (PlainSelect plainSelect : plainSelectList) { + if (Objects.isNull(plainSelect)) { + continue; + } + Expression where = plainSelect.getWhere(); + if (Objects.nonNull(where)) { + where.accept(new FieldAndValueAcquireVisitor(result)); + } } return new ArrayList<>(result); } public static List getHavingExpressions(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); - } + List plainSelectList = getPlainSelect(sql); Set result = new HashSet<>(); - Expression having = plainSelect.getHaving(); - if (Objects.nonNull(having)) { - having.accept(new FieldAndValueAcquireVisitor(result)); + for (PlainSelect plainSelect : plainSelectList) { + if (Objects.isNull(plainSelect)) { + continue; + } + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + having.accept(new FieldAndValueAcquireVisitor(result)); + } } return new ArrayList<>(result); } public static List getOrderByFields(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); - } + List plainSelectList = getPlainSelect(sql); Set result = new HashSet<>(); - getOrderByFields(plainSelect, result); + for (PlainSelect plainSelect : plainSelectList) { + if (Objects.isNull(plainSelect)) { + continue; + } + getOrderByFields(plainSelect, result); + } return new ArrayList<>(result); } @@ -266,20 +295,26 @@ public class SqlParserSelectHelper { } public static List getOrderByExpressions(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); + List plainSelectList = getPlainSelect(sql); + HashSet result = new HashSet<>(); + for (PlainSelect plainSelect : plainSelectList) { + if (Objects.isNull(plainSelect)) { + return new ArrayList<>(); + } + result.addAll(getOrderByFields(plainSelect)); } - return new ArrayList<>(getOrderByFields(plainSelect)); + return new ArrayList<>(result); } public static List getGroupByFields(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); - } + List plainSelectList = getPlainSelect(sql); HashSet result = new HashSet<>(); - getGroupByFields(plainSelect, result); + for (PlainSelect plainSelect : plainSelectList) { + if (Objects.isNull(plainSelect)) { + continue; + } + getGroupByFields(plainSelect, result); + } return new ArrayList<>(result); } @@ -302,21 +337,23 @@ public class SqlParserSelectHelper { } public static List getAggregateFields(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); - } + List plainSelectList = getPlainSelect(sql); Set result = new HashSet<>(); - List selectItems = plainSelect.getSelectItems(); - for (SelectItem selectItem : selectItems) { - if (selectItem instanceof SelectExpressionItem) { - SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; - if (expressionItem.getExpression() instanceof Function) { - Function function = (Function) expressionItem.getExpression(); - if (Objects.nonNull(function.getParameters()) - && !CollectionUtils.isEmpty(function.getParameters().getExpressions())) { - String columnName = function.getParameters().getExpressions().get(0).toString(); - result.add(columnName); + for (PlainSelect plainSelect : plainSelectList) { + if (Objects.isNull(plainSelect)) { + continue; + } + List selectItems = plainSelect.getSelectItems(); + for (SelectItem selectItem : selectItems) { + if (selectItem instanceof SelectExpressionItem) { + SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; + if (expressionItem.getExpression() instanceof Function) { + Function function = (Function) expressionItem.getExpression(); + if (Objects.nonNull(function.getParameters()) + && !CollectionUtils.isEmpty(function.getParameters().getExpressions())) { + String columnName = function.getParameters().getExpressions().get(0).toString(); + result.add(columnName); + } } } } @@ -415,8 +452,16 @@ public class SqlParserSelectHelper { return null; } SelectBody selectBody = selectStatement.getSelectBody(); - PlainSelect plainSelect = (PlainSelect) selectBody; - return (Table) plainSelect.getFromItem(); + if (selectBody instanceof PlainSelect) { + PlainSelect plainSelect = (PlainSelect) selectBody; + return (Table) plainSelect.getFromItem(); + } else if (selectBody instanceof SetOperationList) { + SetOperationList setOperationList = (SetOperationList) selectBody; + if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { + return (Table) ((PlainSelect) setOperationList.getSelects().get(0)).getFromItem(); + } + } + return null; } public static String getDbTableName(String sql) { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java index defb1fe6a..dbb4553a1 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java @@ -54,9 +54,9 @@ class SqlParserAddHelperTest { void addFunctionToSelect() { String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + List havingExpressionList = SqlParserSelectHelper.getHavingExpression(sql); - String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", @@ -64,9 +64,9 @@ class SqlParserAddHelperTest { sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + havingExpressionList = SqlParserSelectHelper.getHavingExpression(sql); - replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", @@ -74,9 +74,9 @@ class SqlParserAddHelperTest { sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + havingExpressionList = SqlParserSelectHelper.getHavingExpression(sql); - replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", @@ -88,9 +88,9 @@ class SqlParserAddHelperTest { void addAggregateToField() { String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + List havingExpressionList = SqlParserSelectHelper.getHavingExpression(sql); - String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", @@ -98,9 +98,9 @@ class SqlParserAddHelperTest { sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + havingExpressionList = SqlParserSelectHelper.getHavingExpression(sql); - replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", @@ -108,9 +108,9 @@ class SqlParserAddHelperTest { sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + havingExpressionList = SqlParserSelectHelper.getHavingExpression(sql); - replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", @@ -360,7 +360,7 @@ class SqlParserAddHelperTest { String replaceSql = SqlParserAddHelper.addHaving(sql, fieldNames); Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 " + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", replaceSql); @@ -370,7 +370,7 @@ class SqlParserAddHelperTest { replaceSql = SqlParserAddHelper.addHaving(sql, fieldNames); Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE (2 > 1) AND sys_imp_date = '2023-09-11' " + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", replaceSql); } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java index dcb456319..0ef454a20 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java @@ -56,7 +56,7 @@ class SqlParserRemoveHelperTest { removeFieldNames.add("播放量"); String replaceSql = SqlParserRemoveHelper.removeHavingCondition(sql, removeFieldNames); Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE 歌手名 = '周杰伦' HAVING 2 > 1", + "SELECT 歌曲名 FROM 歌曲库 WHERE 歌手名 = '周杰伦'", replaceSql); } @@ -74,7 +74,7 @@ class SqlParserRemoveHelperTest { Assert.assertEquals( "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); @@ -84,7 +84,7 @@ class SqlParserRemoveHelperTest { replaceSql = SqlParserRemoveHelper.removeWhereCondition(sql, removeFieldNames); Assert.assertEquals( "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND " + + "AND 数据日期 = '2023-08-09' AND " + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); @@ -93,8 +93,8 @@ class SqlParserRemoveHelperTest { + " order by 播放量 desc limit 11"; replaceSql = SqlParserRemoveHelper.removeWhereCondition(sql, removeFieldNames); Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 1 IN (1) AND 1 IN (1)) AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1) " + + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java index c0acc55bb..719598fb7 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java @@ -159,6 +159,20 @@ class SqlParserReplaceHelperTest { } + @Test + void replaceUnionFields() { + Map fieldToBizName1 = new HashMap<>(); + fieldToBizName1.put("公司成立时间", "company_established_time"); + fieldToBizName1.put("年营业额", "annual_turnover"); + String replaceSql = "SELECT * FROM 互联网企业 ORDER BY 公司成立时间 DESC LIMIT 3 " + + "UNION SELECT * FROM 互联网企业 ORDER BY 年营业额 DESC LIMIT 5"; + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName1); + replaceSql = SqlParserReplaceHelper.replaceTable(replaceSql, "internet"); + Assert.assertEquals( + "SELECT * FROM internet ORDER BY company_established_time DESC LIMIT 3 " + + "UNION SELECT * FROM internet ORDER BY annual_turnover DESC LIMIT 5", replaceSql); + } + @Test void replaceFields() { @@ -348,6 +362,12 @@ class SqlParserReplaceHelperTest { "SELECT 部门, sum(访问次数) FROM s2 WHERE 数据日期 = '2023-08-08' " + "AND 用户 = alice AND 发布日期 = '11' GROUP BY 部门 LIMIT 1", replaceSql); + sql = "select * from 互联网企业 order by 公司成立时间 desc limit 3 union select * from 互联网企业 order by 年营业额 desc limit 5"; + replaceSql = SqlParserReplaceHelper.replaceTable(sql, "internet"); + Assert.assertEquals( + "SELECT * FROM internet ORDER BY 公司成立时间 DESC LIMIT 3 " + + "UNION SELECT * FROM internet ORDER BY 年营业额 DESC LIMIT 5", replaceSql); + sql = "SELECT * FROM CSpider音乐 WHERE (评分 < (SELECT min(评分) " + "FROM CSpider音乐 WHERE 语种 = '英文')) AND 数据日期 = '2023-10-11'"; replaceSql = SqlParserReplaceHelper.replaceTable(sql, "cspider"); diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index c1cb5bbec..54392b2ec 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -246,9 +246,9 @@ class SqlParserSelectHelperTest { String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - Expression leftExpression = SqlParserSelectHelper.getHavingExpression(sql); + List leftExpressionList = SqlParserSelectHelper.getHavingExpression(sql); - Assert.assertEquals(leftExpression.toString(), "sum(pv)"); + Assert.assertEquals(leftExpressionList.get(0).toString(), "sum(pv)"); }