diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index 6dddeb2c2..455643938 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -7,8 +7,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.ArrayList; import java.util.HashSet; @@ -53,7 +53,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); whereFields.removeAll(selectFields); whereFields.remove(DateUtils.DATE_FIELD); - String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); + String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); semanticCorrectInfo.setSql(replaceFields); } @@ -76,7 +76,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return; } - String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate); + String aggregateSql = SqlParserAddHelper.addAggregateToField(sql, metricToAggregate); semanticCorrectInfo.setSql(aggregateSql); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java index 5d73483b6..5d6d6c51d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java @@ -1,8 +1,9 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import java.util.Objects; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; @@ -15,12 +16,12 @@ public class GlobalAfterCorrector extends BaseSemanticCorrector { super.correct(semanticCorrectInfo); String sql = semanticCorrectInfo.getSql(); - if (!SqlParserSelectHelper.hasAggregateFunction(sql)) { + if (!SqlParserSelectFunctionHelper.hasAggregateFunction(sql)) { return; } Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); if (Objects.nonNull(havingExpression)) { - String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); semanticCorrectInfo.setSql(replaceSql); } return; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java index e74a46550..9217888cf 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import java.util.List; import java.util.Map; import java.util.Objects; @@ -33,7 +33,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { } private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { - String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql()); + String replaceAlias = SqlParserReplaceHelper.replaceAlias(semanticCorrectInfo.getSql()); semanticCorrectInfo.setSql(replaceAlias); } @@ -41,7 +41,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { Map fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId()); - String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap); + String sql = SqlParserReplaceHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap); semanticCorrectInfo.setSql(sql); } @@ -56,7 +56,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { Collectors.groupingBy(ElementValue::getFieldValue, Collectors.mapping(ElementValue::getFieldName, Collectors.toSet()))); - String sql = SqlParserUpdateHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(), + String sql = SqlParserReplaceHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(), fieldValueToFieldNames); semanticCorrectInfo.setSql(sql); } @@ -90,7 +90,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { (existingValue, newValue) -> newValue) ))); - String sql = SqlParserUpdateHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false); + String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false); semanticCorrectInfo.setSql(sql); } } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index 87c4328de..684bbfc97 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -4,8 +4,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.List; import java.util.Set; @@ -57,6 +57,6 @@ public class GroupByCorrector extends BaseSemanticCorrector { return true; }) .collect(Collectors.toSet()); - semanticCorrectInfo.setSql(SqlParserUpdateHelper.addGroupBy(sql, groupByFields)); + semanticCorrectInfo.setSql(SqlParserAddHelper.addGroupBy(sql, groupByFields)); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java index 736a056f6..badb1d328 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.Set; import java.util.stream.Collectors; @@ -30,7 +30,7 @@ public class HavingCorrector extends BaseSemanticCorrector { if (CollectionUtils.isEmpty(metrics)) { return; } - String havingSql = SqlParserUpdateHelper.addHaving(semanticCorrectInfo.getSql(), metrics); + String havingSql = SqlParserAddHelper.addHaving(semanticCorrectInfo.getSql(), metrics); semanticCorrectInfo.setSql(havingSql); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index 0a82eec2c..5ef1bb5de 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -10,8 +10,9 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.StringUtil; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.HashMap; import java.util.List; @@ -56,14 +57,14 @@ public class WhereCorrector extends BaseSemanticCorrector { } catch (JSQLParserException e) { log.error("parseCondExpression", e); } - String sql = SqlParserUpdateHelper.addWhere(preSql, expression); + String sql = SqlParserAddHelper.addWhere(preSql, expression); semanticCorrectInfo.setSql(sql); } } private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) { String sql = semanticCorrectInfo.getSql(); - sql = SqlParserUpdateHelper.replaceFunction(sql); + sql = SqlParserReplaceHelper.replaceFunction(sql); semanticCorrectInfo.setSql(sql); } @@ -72,8 +73,8 @@ public class WhereCorrector extends BaseSemanticCorrector { List whereFields = SqlParserSelectHelper.getWhereFields(sql); if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) { String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); - sql = SqlParserUpdateHelper.addParenthesisToWhere(sql); - sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate); + sql = SqlParserAddHelper.addParenthesisToWhere(sql); + sql = SqlParserAddHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate); } semanticCorrectInfo.setSql(sql); } @@ -104,7 +105,7 @@ public class WhereCorrector extends BaseSemanticCorrector { } Map> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); - String sql = SqlParserUpdateHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName); + String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName); semanticCorrectInfo.setSql(sql); return; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java index 893fd9c2f..85c8b6c91 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java @@ -32,6 +32,7 @@ import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; @@ -60,8 +61,6 @@ import org.springframework.web.client.RestTemplate; @Slf4j public class LLMDslParser implements SemanticParser { - public static final double function_bonus_threshold = 201; - @Override public void parse(QueryContext queryCtx, ChatContext chatCtx) { QueryReq request = queryCtx.getRequest(); @@ -159,7 +158,7 @@ public class LLMDslParser implements SemanticParser { Set metrics = getElements(modelId, allFields, semanticSchema.getMetrics()); parseInfo.setMetrics(metrics); - if (SqlParserSelectHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) { + if (SqlParserSelectFunctionHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) { parseInfo.setNativeQuery(false); List groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql()); List groupByDimensions = getFieldsExceptDate(groupByFields); @@ -269,7 +268,7 @@ public class LLMDslParser implements SemanticParser { properties.put("name", dslTool.getName()); parseInfo.setProperties(properties); - parseInfo.setScore(function_bonus_threshold); + parseInfo.setScore(queryCtx.getRequest().getQueryText().length()); parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.getSqlInfo().setLlmParseSql(dslParseResult.getLlmResp().getSqlOutput()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java index c04ba12f4..dc9f570b8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java @@ -10,11 +10,9 @@ import com.tencent.supersonic.chat.query.llm.dsl.DslQuery; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; import java.util.List; - -import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; -@Slf4j + public class EntityInfoParseResponder implements ParseResponder { @Override @@ -37,8 +35,6 @@ public class EntityInfoParseResponder implements ParseResponder { parseInfo.setEntityInfo(entityInfo); } //2. set native value - entityInfo = semanticService.getEntityInfo(parseInfo.getModelId()); - log.info("entityInfo:{}", entityInfo); String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo); if (StringUtils.isNotEmpty(primaryEntityBizName)) { //if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 6fce496e8..29fc2754f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -40,10 +40,11 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; @@ -59,8 +60,6 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; - - import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.collections.CollectionUtils; @@ -155,7 +154,7 @@ public class QueryServiceImpl implements QueryService { } private List getTop5CandidateParseInfo(List selectedParses, - List candidateParses) { + List candidateParses) { if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) { return candidateParses; } @@ -318,9 +317,9 @@ public class QueryServiceImpl implements QueryService { updateDateInfo(queryData, parseInfo, filedNameToValueMap, filterExpressionList); log.info("filedNameToValueMap:{}", filedNameToValueMap); - correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap); + correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); log.info("havingFiledNameToValueMap:{}", havingFiledNameToValueMap); - correctorSql = SqlParserUpdateHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); + correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); log.info("correctorSql after replacing:{}", correctorSql); llmResp.setCorrectorSql(correctorSql); dslParseResult.setLlmResp(llmResp); @@ -355,12 +354,10 @@ public class QueryServiceImpl implements QueryService { return; } Map map = new HashMap<>(); - //List dateFields = new ArrayList<>(QueryStructUtils.internalTimeCols); - String dateField = "数据日期"; + String dateField = DateUtils.DATE_FIELD; if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) { for (FilterExpression filterExpression : filterExpressionList) { - if (filterExpression.getFieldName() != null - && filterExpression.getFieldName().equals("数据日期")) { + if (DateUtils.DATE_FIELD.equals(filterExpression.getFieldName())) { dateField = filterExpression.getFieldName(); map.put(filterExpression.getFieldValue().toString(), queryData.getDateInfo().getStartDate()); @@ -369,7 +366,7 @@ public class QueryServiceImpl implements QueryService { } } else { for (FilterExpression filterExpression : filterExpressionList) { - if (filterExpression.getFieldName().equals("数据日期")) { + if (DateUtils.DATE_FIELD.equals(filterExpression.getFieldName())) { dateField = filterExpression.getFieldName(); if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator()) || FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java index 74c3c5454..dddebd01b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java @@ -1,23 +1,20 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.HashMap; import java.util.Map; import java.util.Objects; - import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.DoubleValue; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.LongValue; -import net.sf.jsqlparser.expression.DoubleValue; import net.sf.jsqlparser.expression.StringValue; -import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; -import net.sf.jsqlparser.expression.operators.relational.GreaterThan; -import net.sf.jsqlparser.expression.operators.relational.MinorThan; -import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; - +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.schema.Column; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -29,7 +26,6 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { private boolean exactReplace; private Map> filedNameToValueMap; - public FieldlValueReplaceVisitor(boolean exactReplace, Map> filedNameToValueMap) { this.exactReplace = exactReplace; this.filedNameToValueMap = filedNameToValueMap; @@ -68,27 +64,11 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { return; } - String columnName = ""; - if (leftExpression instanceof Column) { - Column leftColumnName = (Column) leftExpression; - columnName = leftColumnName.getColumnName(); - } - if (leftExpression instanceof Function) { - Function function = (Function) leftExpression; - columnName = ((Column) function.getParameters().getExpressions().get(0)).getColumnName(); - } + String columnName = SqlParserSelectHelper.getColumnName(leftExpression); if (StringUtils.isEmpty(columnName)) { return; } - - Map valueMap = new HashMap<>(); - for (String key : filedNameToValueMap.keySet()) { - if (columnName.contains(key)) { - valueMap = filedNameToValueMap.get(key); - break; - } - } - //filedNameToValueMap.get(columnName); + Map valueMap = filedNameToValueMap.get(columnName); if (Objects.isNull(valueMap) || valueMap.isEmpty()) { return; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelper.java new file mode 100644 index 000000000..c2fb80b1b --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelper.java @@ -0,0 +1,293 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.Parenthesis; +import net.sf.jsqlparser.expression.StringValue; +import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.statement.select.GroupByElement; +import net.sf.jsqlparser.statement.select.OrderByElement; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; +import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; +import net.sf.jsqlparser.util.SelectUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; + +/** + * Sql Parser add Helper + */ +@Slf4j +public class SqlParserAddHelper { + + public static String addFieldsToSelect(String sql, List fields) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + // add fields to select + for (String field : fields) { + SelectUtils.addExpression(selectStatement, new Column(field)); + } + return selectStatement.toString(); + } + + public static String addFunctionToSelect(String sql, Expression expression) { + PlainSelect plainSelect = SqlParserSelectHelper.getPlainSelect(sql); + if (Objects.isNull(plainSelect)) { + return sql; + } + List selectItems = plainSelect.getSelectItems(); + if (CollectionUtils.isEmpty(selectItems)) { + return sql; + } + boolean existFunction = false; + for (SelectItem selectItem : selectItems) { + SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; + if (expressionItem.getExpression() instanceof Function) { + Function expressionFunction = (Function) expressionItem.getExpression(); + if (expression.toString().equalsIgnoreCase(expressionFunction.toString())) { + existFunction = true; + break; + } + } + } + if (!existFunction) { + SelectExpressionItem sumExpressionItem = new SelectExpressionItem(expression); + selectItems.add(sumExpressionItem); + } + return plainSelect.toString(); + } + + public static String addWhere(String sql, String column, Object value) { + if (StringUtils.isEmpty(column) || Objects.isNull(value)) { + return sql; + } + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + Expression where = plainSelect.getWhere(); + + Expression right = new StringValue(value.toString()); + if (value instanceof Integer || value instanceof Long) { + right = new LongValue(value.toString()); + } + + if (where == null) { + plainSelect.setWhere(new EqualsTo(new Column(column), right)); + } else { + plainSelect.setWhere(new AndExpression(where, new EqualsTo(new Column(column), right))); + } + return selectStatement.toString(); + } + + + public static String addWhere(String sql, Expression expression) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + Expression where = plainSelect.getWhere(); + + if (where == null) { + plainSelect.setWhere(expression); + } else { + plainSelect.setWhere(new AndExpression(where, expression)); + } + return selectStatement.toString(); + } + + public static String addAggregateToField(String sql, Map fieldNameToAggregate) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + selectBody.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate); + addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate); + addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate); + addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate); + } + }); + return selectStatement.toString(); + } + + public static String addGroupBy(String sql, Set groupByFields) { + if (CollectionUtils.isEmpty(groupByFields)) { + return sql; + } + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + + PlainSelect plainSelect = (PlainSelect) selectBody; + GroupByElement groupByElement = new GroupByElement(); + List originalGroupByFields = SqlParserSelectHelper.getGroupByFields(sql); + if (!CollectionUtils.isEmpty(originalGroupByFields)) { + groupByFields.addAll(originalGroupByFields); + } + for (String groupByField : groupByFields) { + groupByElement.addGroupByExpression(new Column(groupByField)); + } + plainSelect.setGroupByElement(groupByElement); + return selectStatement.toString(); + } + + private static void addAggregateToSelectItems(List selectItems, + Map fieldNameToAggregate) { + for (SelectItem selectItem : selectItems) { + if (selectItem instanceof SelectExpressionItem) { + SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; + Expression expression = selectExpressionItem.getExpression(); + Function function = SqlParserSelectFunctionHelper.getFunction(expression, fieldNameToAggregate); + if (function == null) { + continue; + } + selectExpressionItem.setExpression(function); + } + } + } + + private static void addAggregateToOrderByItems(List orderByElements, + Map fieldNameToAggregate) { + if (orderByElements == null) { + return; + } + for (OrderByElement orderByElement : orderByElements) { + Expression expression = orderByElement.getExpression(); + Function function = SqlParserSelectFunctionHelper.getFunction(expression, fieldNameToAggregate); + if (function == null) { + continue; + } + orderByElement.setExpression(function); + } + } + + private static void addAggregateToGroupByItems(GroupByElement groupByElement, + Map fieldNameToAggregate) { + if (groupByElement == null) { + return; + } + for (Expression expression : groupByElement.getGroupByExpressions()) { + Function function = SqlParserSelectFunctionHelper.getFunction(expression, fieldNameToAggregate); + if (function == null) { + continue; + } + groupByElement.addGroupByExpression(function); + } + } + + private static void addAggregateToWhereItems(Expression whereExpression, Map fieldNameToAggregate) { + if (whereExpression == null) { + return; + } + modifyWhereExpression(whereExpression, fieldNameToAggregate); + } + + private static void modifyWhereExpression(Expression whereExpression, + Map fieldNameToAggregate) { + if (SqlParserSelectHelper.isLogicExpression(whereExpression)) { + AndExpression andExpression = (AndExpression) whereExpression; + Expression leftExpression = andExpression.getLeftExpression(); + Expression rightExpression = andExpression.getRightExpression(); + modifyWhereExpression(leftExpression, fieldNameToAggregate); + modifyWhereExpression(rightExpression, fieldNameToAggregate); + } else if (whereExpression instanceof Parenthesis) { + modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate); + } else { + setAggToFunction(whereExpression, fieldNameToAggregate); + } + } + + private static void setAggToFunction(Expression expression, Map fieldNameToAggregate) { + if (!(expression instanceof ComparisonOperator)) { + return; + } + ComparisonOperator comparisonOperator = (ComparisonOperator) expression; + if (comparisonOperator.getRightExpression() instanceof Column) { + String columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName(); + Function function = SqlParserSelectFunctionHelper.getFunction(comparisonOperator.getRightExpression(), + fieldNameToAggregate.get(columnName)); + if (Objects.nonNull(function)) { + comparisonOperator.setRightExpression(function); + } + } + if (comparisonOperator.getLeftExpression() instanceof Column) { + String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName(); + Function function = SqlParserSelectFunctionHelper.getFunction(comparisonOperator.getLeftExpression(), + fieldNameToAggregate.get(columnName)); + if (Objects.nonNull(function)) { + comparisonOperator.setLeftExpression(function); + } + } + } + + public static String addHaving(String sql, Set fieldNames) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + + PlainSelect plainSelect = (PlainSelect) selectBody; + //replace metric to 1 and 1 and add having metric + Expression where = plainSelect.getWhere(); + FiledFilterReplaceVisitor visitor = new FiledFilterReplaceVisitor(fieldNames); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + List waitingForAdds = visitor.getWaitingForAdds(); + if (!CollectionUtils.isEmpty(waitingForAdds)) { + for (Expression waitingForAdd : waitingForAdds) { + Expression having = plainSelect.getHaving(); + if (Objects.isNull(having)) { + plainSelect.setHaving(waitingForAdd); + } else { + plainSelect.setHaving(new AndExpression(having, waitingForAdd)); + } + } + } + return selectStatement.toString(); + } + + public static String addParenthesisToWhere(String sql) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + Expression where = plainSelect.getWhere(); + if (Objects.nonNull(where)) { + Parenthesis parenthesis = new Parenthesis(where); + plainSelect.setWhere(parenthesis); + } + return selectStatement.toString(); + } +} + diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java new file mode 100644 index 000000000..53663bf82 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java @@ -0,0 +1,100 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Parenthesis; +import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.InExpression; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; + +/** + * Sql Parser remove Helper + */ +@Slf4j +public class SqlParserRemoveHelper { + + public static String removeWhereCondition(String sql, Set removeFieldNames) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + selectBody.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + removeWhereCondition(plainSelect.getWhere(), removeFieldNames); + } + }); + return selectStatement.toString(); + } + + private static void removeWhereCondition(Expression whereExpression, Set removeFieldNames) { + if (whereExpression == null) { + return; + } + removeWhereExpression(whereExpression, removeFieldNames); + } + + private static void removeWhereExpression(Expression whereExpression, Set removeFieldNames) { + if (SqlParserSelectHelper.isLogicExpression(whereExpression)) { + AndExpression andExpression = (AndExpression) whereExpression; + Expression leftExpression = andExpression.getLeftExpression(); + Expression rightExpression = andExpression.getRightExpression(); + + removeWhereExpression(leftExpression, removeFieldNames); + removeWhereExpression(rightExpression, removeFieldNames); + } else if (whereExpression instanceof Parenthesis) { + removeWhereExpression(((Parenthesis) whereExpression).getExpression(), removeFieldNames); + } else { + removeExpressionWithConstant(whereExpression, removeFieldNames); + } + } + + private static void removeExpressionWithConstant(Expression expression, Set removeFieldNames) { + if (expression instanceof EqualsTo) { + ComparisonOperator comparisonOperator = (ComparisonOperator) expression; + String columnName = SqlParserSelectHelper.getColumnName(comparisonOperator.getLeftExpression(), + comparisonOperator.getRightExpression()); + if (!removeFieldNames.contains(columnName)) { + return; + } + try { + ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression( + JsqlConstants.EQUAL_CONSTANT); + comparisonOperator.setLeftExpression(constantExpression.getLeftExpression()); + comparisonOperator.setRightExpression(constantExpression.getRightExpression()); + comparisonOperator.setASTNode(constantExpression.getASTNode()); + } catch (JSQLParserException e) { + log.error("JSQLParserException", e); + } + } + if (expression instanceof InExpression) { + InExpression inExpression = (InExpression) expression; + String columnName = SqlParserSelectHelper.getColumnName(inExpression.getLeftExpression(), + inExpression.getRightExpression()); + if (!removeFieldNames.contains(columnName)) { + return; + } + try { + InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression( + JsqlConstants.IN_CONSTANT); + inExpression.setLeftExpression(constantExpression.getLeftExpression()); + inExpression.setRightItemsList(constantExpression.getRightItemsList()); + inExpression.setASTNode(constantExpression.getASTNode()); + } catch (JSQLParserException e) { + log.error("JSQLParserException", e); + } + } + } + +} + diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java new file mode 100644 index 000000000..cc065894b --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java @@ -0,0 +1,244 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.statement.select.GroupByElement; +import net.sf.jsqlparser.statement.select.OrderByElement; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; + +/** + * Sql Parser replace Helper + */ +@Slf4j +public class SqlParserReplaceHelper { + + public static String replaceValue(String sql, Map> filedNameToValueMap) { + return replaceValue(sql, filedNameToValueMap, true); + } + + public static String replaceValue(String sql, Map> filedNameToValueMap, + boolean exactReplace) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + for (PlainSelect plainSelect : plainSelects) { + Expression where = plainSelect.getWhere(); + FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + } + return selectStatement.toString(); + } + + public static String replaceFieldNameByValue(String sql, Map> fieldValueToFieldNames) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + for (PlainSelect plainSelect : plainSelects) { + Expression where = plainSelect.getWhere(); + FiledNameReplaceVisitor visitor = new FiledNameReplaceVisitor(fieldValueToFieldNames); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + } + return selectStatement.toString(); + } + + public static String replaceFields(String sql, Map fieldNameMap) { + return replaceFields(sql, fieldNameMap, false); + } + + public static String replaceFields(String sql, Map fieldNameMap, boolean exactReplace) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + for (PlainSelect plainSelect : plainSelects) { + replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, plainSelect); + } + return selectStatement.toString(); + } + + private static void replaceFieldsInPlainOneSelect(Map fieldNameMap, boolean exactReplace, + PlainSelect plainSelect) { + //1. replace where fields + Expression where = plainSelect.getWhere(); + FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + + //2. replace select fields + for (SelectItem selectItem : plainSelect.getSelectItems()) { + selectItem.accept(visitor); + } + + //3. replace oder by fields + List orderByElements = plainSelect.getOrderByElements(); + if (!CollectionUtils.isEmpty(orderByElements)) { + for (OrderByElement orderByElement : orderByElements) { + orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace)); + } + } + //4. replace group by fields + GroupByElement groupByElement = plainSelect.getGroupBy(); + if (Objects.nonNull(groupByElement)) { + groupByElement.accept(new GroupByReplaceVisitor(fieldNameMap, exactReplace)); + } + //5. replace having fields + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + having.accept(visitor); + } + } + + public static String replaceFunction(String sql, Map functionMap) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + for (PlainSelect plainSelect : plainSelects) { + replaceFunction(functionMap, plainSelect); + } + return selectStatement.toString(); + } + + private static void replaceFunction(Map functionMap, PlainSelect selectBody) { + PlainSelect plainSelect = selectBody; + //1. replace where dataDiff function + Expression where = plainSelect.getWhere(); + + FunctionNameReplaceVisitor visitor = new FunctionNameReplaceVisitor(functionMap); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + GroupByElement groupBy = plainSelect.getGroupBy(); + if (Objects.nonNull(groupBy)) { + GroupByFunctionReplaceVisitor replaceVisitor = new GroupByFunctionReplaceVisitor(functionMap); + groupBy.accept(replaceVisitor); + } + + for (SelectItem selectItem : plainSelect.getSelectItems()) { + selectItem.accept(visitor); + } + } + + public static String replaceFunction(String sql) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List plainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) selectBody); + for (PlainSelect plainSelect : plainSelects) { + replaceFunction(plainSelect); + } + return selectStatement.toString(); + } + + private static void replaceFunction(PlainSelect selectBody) { + PlainSelect plainSelect = selectBody; + + //1. replace where dataDiff function + Expression where = plainSelect.getWhere(); + FunctionReplaceVisitor visitor = new FunctionReplaceVisitor(); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + //2. add Waiting Expression + List waitingForAdds = visitor.getWaitingForAdds(); + addWaitingExpression(plainSelect, where, waitingForAdds); + } + + private static void addWaitingExpression(PlainSelect plainSelect, Expression where, + List waitingForAdds) { + if (CollectionUtils.isEmpty(waitingForAdds)) { + return; + } + for (Expression expression : waitingForAdds) { + if (where == null) { + plainSelect.setWhere(expression); + } else { + where = new AndExpression(where, expression); + } + } + plainSelect.setWhere(where); + } + + public static String replaceTable(String sql, String tableName) { + if (StringUtils.isEmpty(tableName)) { + return sql; + } + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + PlainSelect plainSelect = (PlainSelect) selectBody; + // replace table name + List painSelects = SqlParserSelectHelper.getPlainSelects(plainSelect); + for (PlainSelect painSelect : painSelects) { + painSelect.accept( + new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName)); + } + }); + } + return selectStatement.toString(); + } + + public static String replaceAlias(String sql) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + FunctionAliasReplaceVisitor visitor = new FunctionAliasReplaceVisitor(); + for (SelectItem selectItem : plainSelect.getSelectItems()) { + selectItem.accept(visitor); + } + Map aliasToActualExpression = visitor.getAliasToActualExpression(); + if (Objects.nonNull(aliasToActualExpression) && !aliasToActualExpression.isEmpty()) { + return replaceFields(selectStatement.toString(), aliasToActualExpression, true); + } + return selectStatement.toString(); + } + + public static String replaceHavingValue(String sql, Map> filedNameToValueMap) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + Expression having = plainSelect.getHaving(); + FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap); + if (Objects.nonNull(having)) { + having.accept(visitor); + } + return selectStatement.toString(); + } +} + diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java new file mode 100644 index 000000000..c07e2ee2e --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelper.java @@ -0,0 +1,76 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SelectItem; +import org.apache.commons.lang3.StringUtils; + +/** + * Sql Parser Select function Helper + */ +@Slf4j +public class SqlParserSelectFunctionHelper { + + public static boolean hasAggregateFunction(String sql) { + if (hasFunction(sql)) { + return true; + } + return SqlParserSelectHelper.hasGroupBy(sql); + } + + public static boolean hasFunction(String sql) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return false; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + List selectItems = plainSelect.getSelectItems(); + AggregateFunctionVisitor visitor = new AggregateFunctionVisitor(); + for (SelectItem selectItem : selectItems) { + selectItem.accept(visitor); + } + boolean selectFunction = visitor.hasAggregateFunction(); + if (selectFunction) { + return true; + } + return false; + } + + public static Function getFunction(Expression expression, Map fieldNameToAggregate) { + if (!(expression instanceof Column)) { + return null; + } + String columnName = ((Column) expression).getColumnName(); + if (StringUtils.isEmpty(columnName)) { + return null; + } + Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); + if (Objects.isNull(function)) { + return null; + } + return function; + } + + public static Function getFunction(Expression expression, String aggregateName) { + if (StringUtils.isEmpty(aggregateName)) { + return null; + } + Function sumFunction = new Function(); + sumFunction.setName(aggregateName); + sumFunction.setParameters(new ExpressionList(expression)); + return sumFunction; + } + +} + diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 42d33bc97..3e9f92a7b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -8,7 +8,11 @@ import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.conditional.OrExpression; +import net.sf.jsqlparser.expression.operators.conditional.XorExpression; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; @@ -21,6 +25,8 @@ import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; +import net.sf.jsqlparser.statement.select.SubSelect; import org.springframework.util.CollectionUtils; /** @@ -109,6 +115,28 @@ public class SqlParserSelectHelper { return (Select) statement; } + public static List getPlainSelects(PlainSelect plainSelect) { + List plainSelects = new ArrayList<>(); + plainSelects.add(plainSelect); + plainSelect.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + Expression whereExpression = plainSelect.getWhere(); + if (whereExpression != null) { + whereExpression.accept(new ExpressionVisitorAdapter() { + @Override + public void visit(SubSelect subSelect) { + SelectBody subSelectBody = subSelect.getSelectBody(); + if (subSelectBody instanceof PlainSelect) { + plainSelects.add((PlainSelect) subSelectBody); + } + } + }); + } + } + }); + return plainSelects; + } public static List getAllFields(String sql) { @@ -244,14 +272,6 @@ public class SqlParserSelectHelper { return new ArrayList<>(result); } - - public static boolean hasAggregateFunction(String sql) { - if (hasFunction(sql)) { - return true; - } - return hasGroupBy(sql); - } - public static boolean hasGroupBy(String sql) { Select selectStatement = getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); @@ -269,24 +289,36 @@ public class SqlParserSelectHelper { return false; } - public static boolean hasFunction(String sql) { - Select selectStatement = getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); + public static boolean isLogicExpression(Expression whereExpression) { + return whereExpression instanceof AndExpression || (whereExpression instanceof OrExpression + || (whereExpression instanceof XorExpression)); + } - if (!(selectBody instanceof PlainSelect)) { - return false; + public static String getColumnName(Expression leftExpression, Expression rightExpression) { + if (leftExpression instanceof Column) { + return ((Column) leftExpression).getColumnName(); } - PlainSelect plainSelect = (PlainSelect) selectBody; - List selectItems = plainSelect.getSelectItems(); - AggregateFunctionVisitor visitor = new AggregateFunctionVisitor(); - for (SelectItem selectItem : selectItems) { - selectItem.accept(visitor); + if (rightExpression instanceof Column) { + return ((Column) rightExpression).getColumnName(); } - boolean selectFunction = visitor.hasAggregateFunction(); - if (selectFunction) { - return true; + return ""; + } + + public static String getColumnName(Expression leftExpression) { + if (leftExpression instanceof Column) { + Column leftColumnName = (Column) leftExpression; + return leftColumnName.getColumnName(); } - return false; + if (leftExpression instanceof Function) { + Function function = (Function) leftExpression; + if (!CollectionUtils.isEmpty(function.getParameters().getExpressions())) { + Expression expression = function.getParameters().getExpressions().get(0); + if (expression instanceof Column) { + return ((Column) expression).getColumnName(); + } + } + } + return ""; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java deleted file mode 100644 index 28cf55336..000000000 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ /dev/null @@ -1,604 +0,0 @@ -package com.tencent.supersonic.common.util.jsqlparser; - -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.JSQLParserException; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.Function; -import net.sf.jsqlparser.expression.LongValue; -import net.sf.jsqlparser.expression.Parenthesis; -import net.sf.jsqlparser.expression.StringValue; -import net.sf.jsqlparser.expression.operators.conditional.AndExpression; -import net.sf.jsqlparser.expression.operators.conditional.OrExpression; -import net.sf.jsqlparser.expression.operators.conditional.XorExpression; -import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; -import net.sf.jsqlparser.expression.operators.relational.EqualsTo; -import net.sf.jsqlparser.expression.operators.relational.ExpressionList; -import net.sf.jsqlparser.expression.operators.relational.InExpression; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; -import net.sf.jsqlparser.schema.Column; -import net.sf.jsqlparser.schema.Table; -import net.sf.jsqlparser.statement.select.GroupByElement; -import net.sf.jsqlparser.statement.select.OrderByElement; -import net.sf.jsqlparser.statement.select.PlainSelect; -import net.sf.jsqlparser.statement.select.Select; -import net.sf.jsqlparser.statement.select.SelectBody; -import net.sf.jsqlparser.statement.select.SelectExpressionItem; -import net.sf.jsqlparser.statement.select.SelectItem; -import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; -import net.sf.jsqlparser.util.SelectUtils; -import org.apache.commons.lang3.StringUtils; -import org.springframework.util.CollectionUtils; - -/** - * Sql Parser Update Helper - */ -@Slf4j -public class SqlParserUpdateHelper { - - public static String replaceValue(String sql, Map> filedNameToValueMap) { - return replaceValue(sql, filedNameToValueMap, true); - } - - public static String replaceValue(String sql, Map> filedNameToValueMap, - boolean exactReplace) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - Expression where = plainSelect.getWhere(); - FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap); - if (Objects.nonNull(where)) { - where.accept(visitor); - } - return selectStatement.toString(); - } - - public static String replaceHavingValue(String sql, Map> filedNameToValueMap) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - Expression having = plainSelect.getHaving(); - FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap); - if (Objects.nonNull(having)) { - having.accept(visitor); - } - return selectStatement.toString(); - } - - public static String replaceFieldNameByValue(String sql, Map> fieldValueToFieldNames) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - Expression where = plainSelect.getWhere(); - FiledNameReplaceVisitor visitor = new FiledNameReplaceVisitor(fieldValueToFieldNames); - if (Objects.nonNull(where)) { - where.accept(visitor); - } - return selectStatement.toString(); - } - - public static String replaceFields(String sql, Map fieldNameMap) { - return replaceFields(sql, fieldNameMap, false); - } - - public static String replaceFields(String sql, Map fieldNameMap, boolean exactReplace) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - //1. replace where fields - Expression where = plainSelect.getWhere(); - FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace); - if (Objects.nonNull(where)) { - where.accept(visitor); - } - - //2. replace select fields - for (SelectItem selectItem : plainSelect.getSelectItems()) { - selectItem.accept(visitor); - } - - //3. replace oder by fields - List orderByElements = plainSelect.getOrderByElements(); - if (!CollectionUtils.isEmpty(orderByElements)) { - for (OrderByElement orderByElement : orderByElements) { - orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace)); - } - } - - //4. replace group by fields - GroupByElement groupByElement = plainSelect.getGroupBy(); - if (Objects.nonNull(groupByElement)) { - groupByElement.accept(new GroupByReplaceVisitor(fieldNameMap, exactReplace)); - } - //5. replace having fields - Expression having = plainSelect.getHaving(); - if (Objects.nonNull(having)) { - having.accept(visitor); - } - return selectStatement.toString(); - } - - public static String replaceFunction(String sql, Map functionMap) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - //1. replace where dataDiff function - Expression where = plainSelect.getWhere(); - - FunctionNameReplaceVisitor visitor = new FunctionNameReplaceVisitor(functionMap); - if (Objects.nonNull(where)) { - where.accept(visitor); - } - GroupByElement groupBy = plainSelect.getGroupBy(); - if (Objects.nonNull(groupBy)) { - GroupByFunctionReplaceVisitor replaceVisitor = new GroupByFunctionReplaceVisitor(functionMap); - groupBy.accept(replaceVisitor); - } - - for (SelectItem selectItem : plainSelect.getSelectItems()) { - selectItem.accept(visitor); - } - return selectStatement.toString(); - } - - public static String replaceFunction(String sql) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - //1. replace where dataDiff function - Expression where = plainSelect.getWhere(); - FunctionReplaceVisitor visitor = new FunctionReplaceVisitor(); - if (Objects.nonNull(where)) { - where.accept(visitor); - } - //2. add Waiting Expression - List waitingForAdds = visitor.getWaitingForAdds(); - addWaitingExpression(plainSelect, where, waitingForAdds); - return selectStatement.toString(); - } - - private static void addWaitingExpression(PlainSelect plainSelect, Expression where, - List waitingForAdds) { - if (CollectionUtils.isEmpty(waitingForAdds)) { - return; - } - for (Expression expression : waitingForAdds) { - if (where == null) { - plainSelect.setWhere(expression); - } else { - where = new AndExpression(where, expression); - } - } - plainSelect.setWhere(where); - } - - - public static String addFieldsToSelect(String sql, List fields) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - // add fields to select - for (String field : fields) { - SelectUtils.addExpression(selectStatement, new Column(field)); - } - return selectStatement.toString(); - } - - public static String addFunctionToSelect(String sql, Expression expression) { - PlainSelect plainSelect = SqlParserSelectHelper.getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return sql; - } - List selectItems = plainSelect.getSelectItems(); - if (CollectionUtils.isEmpty(selectItems)) { - return sql; - } - boolean existFunction = false; - for (SelectItem selectItem : selectItems) { - SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem; - if (expressionItem.getExpression() instanceof Function) { - Function expressionFunction = (Function) expressionItem.getExpression(); - if (expression.toString().equalsIgnoreCase(expressionFunction.toString())) { - existFunction = true; - break; - } - } - } - if (!existFunction) { - SelectExpressionItem sumExpressionItem = new SelectExpressionItem(expression); - selectItems.add(sumExpressionItem); - } - return plainSelect.toString(); - } - - public static String replaceTable(String sql, String tableName) { - if (StringUtils.isEmpty(tableName)) { - return sql; - } - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - PlainSelect plainSelect = (PlainSelect) selectBody; - // replace table name - Table table = (Table) plainSelect.getFromItem(); - table.setName(tableName); - return selectStatement.toString(); - } - - - public static String replaceAlias(String sql) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - FunctionAliasReplaceVisitor visitor = new FunctionAliasReplaceVisitor(); - for (SelectItem selectItem : plainSelect.getSelectItems()) { - selectItem.accept(visitor); - } - Map aliasToActualExpression = visitor.getAliasToActualExpression(); - if (Objects.nonNull(aliasToActualExpression) && !aliasToActualExpression.isEmpty()) { - return replaceFields(selectStatement.toString(), aliasToActualExpression, true); - } - return selectStatement.toString(); - } - - public static String addWhere(String sql, String column, Object value) { - if (StringUtils.isEmpty(column) || Objects.isNull(value)) { - return sql; - } - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - Expression where = plainSelect.getWhere(); - - Expression right = new StringValue(value.toString()); - if (value instanceof Integer || value instanceof Long) { - right = new LongValue(value.toString()); - } - - if (where == null) { - plainSelect.setWhere(new EqualsTo(new Column(column), right)); - } else { - plainSelect.setWhere(new AndExpression(where, new EqualsTo(new Column(column), right))); - } - return selectStatement.toString(); - } - - - public static String addWhere(String sql, Expression expression) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - Expression where = plainSelect.getWhere(); - - if (where == null) { - plainSelect.setWhere(expression); - } else { - plainSelect.setWhere(new AndExpression(where, expression)); - } - return selectStatement.toString(); - } - - public static String addAggregateToField(String sql, Map fieldNameToAggregate) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - selectBody.accept(new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate); - addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate); - addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate); - addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate); - } - }); - return selectStatement.toString(); - } - - public static String addGroupBy(String sql, Set groupByFields) { - if (CollectionUtils.isEmpty(groupByFields)) { - return sql; - } - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - - PlainSelect plainSelect = (PlainSelect) selectBody; - GroupByElement groupByElement = new GroupByElement(); - List originalGroupByFields = SqlParserSelectHelper.getGroupByFields(sql); - if (!CollectionUtils.isEmpty(originalGroupByFields)) { - groupByFields.addAll(originalGroupByFields); - } - for (String groupByField : groupByFields) { - groupByElement.addGroupByExpression(new Column(groupByField)); - } - plainSelect.setGroupByElement(groupByElement); - return selectStatement.toString(); - } - - private static void addAggregateToSelectItems(List selectItems, - Map fieldNameToAggregate) { - for (SelectItem selectItem : selectItems) { - if (selectItem instanceof SelectExpressionItem) { - SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; - Expression expression = selectExpressionItem.getExpression(); - Function function = getFunction(expression, fieldNameToAggregate); - if (function == null) { - continue; - } - selectExpressionItem.setExpression(function); - } - } - } - - private static void addAggregateToOrderByItems(List orderByElements, - Map fieldNameToAggregate) { - if (orderByElements == null) { - return; - } - for (OrderByElement orderByElement : orderByElements) { - Expression expression = orderByElement.getExpression(); - Function function = getFunction(expression, fieldNameToAggregate); - if (function == null) { - continue; - } - orderByElement.setExpression(function); - } - } - - private static void addAggregateToGroupByItems(GroupByElement groupByElement, - Map fieldNameToAggregate) { - if (groupByElement == null) { - return; - } - for (Expression expression : groupByElement.getGroupByExpressions()) { - Function function = getFunction(expression, fieldNameToAggregate); - if (function == null) { - continue; - } - groupByElement.addGroupByExpression(function); - } - } - - private static void addAggregateToWhereItems(Expression whereExpression, Map fieldNameToAggregate) { - if (whereExpression == null) { - return; - } - modifyWhereExpression(whereExpression, fieldNameToAggregate); - } - - private static void modifyWhereExpression(Expression whereExpression, - Map fieldNameToAggregate) { - if (isLogicExpression(whereExpression)) { - AndExpression andExpression = (AndExpression) whereExpression; - Expression leftExpression = andExpression.getLeftExpression(); - Expression rightExpression = andExpression.getRightExpression(); - modifyWhereExpression(leftExpression, fieldNameToAggregate); - modifyWhereExpression(rightExpression, fieldNameToAggregate); - } else if (whereExpression instanceof Parenthesis) { - modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate); - } else { - setAggToFunction(whereExpression, fieldNameToAggregate); - } - } - - private static boolean isLogicExpression(Expression whereExpression) { - return whereExpression instanceof AndExpression || (whereExpression instanceof OrExpression - || (whereExpression instanceof XorExpression)); - } - - - private static void setAggToFunction(Expression expression, Map fieldNameToAggregate) { - if (!(expression instanceof ComparisonOperator)) { - return; - } - ComparisonOperator comparisonOperator = (ComparisonOperator) expression; - if (comparisonOperator.getRightExpression() instanceof Column) { - String columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName(); - Function function = getFunction(comparisonOperator.getRightExpression(), - fieldNameToAggregate.get(columnName)); - if (Objects.nonNull(function)) { - comparisonOperator.setRightExpression(function); - } - } - if (comparisonOperator.getLeftExpression() instanceof Column) { - String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName(); - Function function = getFunction(comparisonOperator.getLeftExpression(), - fieldNameToAggregate.get(columnName)); - if (Objects.nonNull(function)) { - comparisonOperator.setLeftExpression(function); - } - } - } - - - private static Function getFunction(Expression expression, Map fieldNameToAggregate) { - if (!(expression instanceof Column)) { - return null; - } - String columnName = ((Column) expression).getColumnName(); - if (StringUtils.isEmpty(columnName)) { - return null; - } - Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); - if (Objects.isNull(function)) { - return null; - } - return function; - } - - private static Function getFunction(Expression expression, String aggregateName) { - if (StringUtils.isEmpty(aggregateName)) { - return null; - } - Function sumFunction = new Function(); - sumFunction.setName(aggregateName); - sumFunction.setParameters(new ExpressionList(expression)); - return sumFunction; - } - - public static String addHaving(String sql, Set fieldNames) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - - PlainSelect plainSelect = (PlainSelect) selectBody; - //replace metric to 1 and 1 and add having metric - Expression where = plainSelect.getWhere(); - FiledFilterReplaceVisitor visitor = new FiledFilterReplaceVisitor(fieldNames); - if (Objects.nonNull(where)) { - where.accept(visitor); - } - List waitingForAdds = visitor.getWaitingForAdds(); - if (!CollectionUtils.isEmpty(waitingForAdds)) { - for (Expression waitingForAdd : waitingForAdds) { - Expression having = plainSelect.getHaving(); - if (Objects.isNull(having)) { - plainSelect.setHaving(waitingForAdd); - } else { - plainSelect.setHaving(new AndExpression(having, waitingForAdd)); - } - } - } - return selectStatement.toString(); - } - - public static String removeWhereCondition(String sql, Set removeFieldNames) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - selectBody.accept(new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - removeWhereCondition(plainSelect.getWhere(), removeFieldNames); - } - }); - return selectStatement.toString(); - } - - private static void removeWhereCondition(Expression whereExpression, Set removeFieldNames) { - if (whereExpression == null) { - return; - } - removeWhereExpression(whereExpression, removeFieldNames); - } - - private static void removeWhereExpression(Expression whereExpression, Set removeFieldNames) { - if (isLogicExpression(whereExpression)) { - AndExpression andExpression = (AndExpression) whereExpression; - Expression leftExpression = andExpression.getLeftExpression(); - Expression rightExpression = andExpression.getRightExpression(); - - removeWhereExpression(leftExpression, removeFieldNames); - removeWhereExpression(rightExpression, removeFieldNames); - } else if (whereExpression instanceof Parenthesis) { - removeWhereExpression(((Parenthesis) whereExpression).getExpression(), removeFieldNames); - } else { - removeExpressionWithConstant(whereExpression, removeFieldNames); - } - } - - private static void removeExpressionWithConstant(Expression expression, Set removeFieldNames) { - if (expression instanceof EqualsTo) { - ComparisonOperator comparisonOperator = (ComparisonOperator) expression; - String columnName = getColumnName(comparisonOperator.getLeftExpression(), - comparisonOperator.getRightExpression()); - if (!removeFieldNames.contains(columnName)) { - return; - } - try { - ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression( - JsqlConstants.EQUAL_CONSTANT); - comparisonOperator.setLeftExpression(constantExpression.getLeftExpression()); - comparisonOperator.setRightExpression(constantExpression.getRightExpression()); - comparisonOperator.setASTNode(constantExpression.getASTNode()); - } catch (JSQLParserException e) { - log.error("JSQLParserException", e); - } - } - if (expression instanceof InExpression) { - InExpression inExpression = (InExpression) expression; - String columnName = getColumnName(inExpression.getLeftExpression(), inExpression.getRightExpression()); - if (!removeFieldNames.contains(columnName)) { - return; - } - try { - InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression( - JsqlConstants.IN_CONSTANT); - inExpression.setLeftExpression(constantExpression.getLeftExpression()); - inExpression.setRightItemsList(constantExpression.getRightItemsList()); - inExpression.setASTNode(constantExpression.getASTNode()); - } catch (JSQLParserException e) { - log.error("JSQLParserException", e); - } - } - } - - private static String getColumnName(Expression leftExpression, Expression rightExpression) { - String columnName = ""; - if (leftExpression instanceof Column) { - columnName = ((Column) leftExpression).getColumnName(); - } - if (rightExpression instanceof Column) { - columnName = ((Column) rightExpression).getColumnName(); - } - return columnName; - } - - public static String addParenthesisToWhere(String sql) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - PlainSelect plainSelect = (PlainSelect) selectBody; - Expression where = plainSelect.getWhere(); - if (Objects.nonNull(where)) { - Parenthesis parenthesis = new Parenthesis(where); - plainSelect.setWhere(parenthesis); - } - return selectStatement.toString(); - } -} - diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/TableNameReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/TableNameReplaceVisitor.java new file mode 100644 index 000000000..361e8eb64 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/TableNameReplaceVisitor.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.statement.select.FromItemVisitorAdapter; + +public class TableNameReplaceVisitor extends FromItemVisitorAdapter { + + private String tableName; + + public TableNameReplaceVisitor(String tableName) { + this.tableName = tableName; + } + + @Override + public void visit(Table table) { + table.setName(tableName); + } +} \ No newline at end of file diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java new file mode 100644 index 000000000..02eb32de5 --- /dev/null +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java @@ -0,0 +1,278 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +/** + * SqlParserAddHelperTest Test + */ +class SqlParserAddHelperTest { + + @Test + void addWhere() throws JSQLParserException { + + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + sql = SqlParserAddHelper.addWhere(sql, "column_a", 123444555); + List selectFields = SqlParserSelectHelper.getAllFields(sql); + + Assert.assertEquals(selectFields.contains("column_a"), true); + + sql = SqlParserAddHelper.addWhere(sql, "column_b", "123456666"); + selectFields = SqlParserSelectHelper.getAllFields(sql); + + Assert.assertEquals(selectFields.contains("column_b"), true); + + Expression expression = CCJSqlParserUtil.parseCondExpression(" ( column_c = 111 or column_d = 1111)"); + + sql = SqlParserAddHelper.addWhere( + "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1", + expression); + + Assert.assertEquals(sql.contains("column_c = 111"), true); + + sql = "select 部门,sum (访问次数) from 超音数 where 用户 = alice or 发布日期 ='2023-07-03' group by 部门 limit 1"; + sql = SqlParserAddHelper.addParenthesisToWhere(sql); + sql = SqlParserAddHelper.addWhere(sql, "数据日期", "2023-08-08"); + Assert.assertEquals(sql, "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "(用户 = alice OR 发布日期 = '2023-07-03') AND 数据日期 = '2023-08-08' GROUP BY 部门 LIMIT 1"); + + } + + + @Test + void addFunctionToSelect() { + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " + + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + } + + @Test + void addAggregateToField() { + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " + + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + } + + + @Test + void addAggregateToMetricField() { + String sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' order by pv desc limit 10"; + + Map filedNameToAggregate = new HashMap<>(); + filedNameToAggregate.put("pv", "sum"); + + Set groupByFields = new HashSet<>(); + groupByFields.add("department"); + + String replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " + + "order by pv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where pv >1 order by pv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where sum(pv) >1 order by pv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 " + + "GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " + + "GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 and department = 'HR' " + + "GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 " + + "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, pv from t_1 where (pv >1 and department = 'HR') " + + " and sys_imp_date = '2023-09-11' GROUP BY department order by pv desc limit 10"; + replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE (sum(pv) > 1 AND department = 'HR') AND " + + "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + @Test + void addGroupBy() { + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " + + "order by sum(pv) desc limit 10"; + + Set groupByFields = new HashSet<>(); + groupByFields.add("department"); + + String replaceSql = SqlParserAddHelper.addGroupBy(sql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, sum(pv) from t_1 where (department = 'HR') and sys_imp_date = '2023-09-11' " + + "order by sum(pv) desc limit 10"; + + replaceSql = SqlParserAddHelper.addGroupBy(sql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE (department = 'HR') AND sys_imp_date " + + "= '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + @Test + void addHaving() { + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and " + + "sum(pv) > 2000 group by department order by sum(pv) desc limit 10"; + List groupByFields = new ArrayList<>(); + groupByFields.add("department"); + + Set fieldNames = new HashSet<>(); + fieldNames.add("pv"); + + String replaceSql = SqlParserAddHelper.addHaving(sql, fieldNames); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 " + + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + + sql = "select department, sum(pv) from t_1 where (sum(pv) > 2000) and sys_imp_date = '2023-09-11' " + + "group by department order by sum(pv) desc limit 10"; + + replaceSql = SqlParserAddHelper.addHaving(sql, fieldNames); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE (2 > 1) AND sys_imp_date = '2023-09-11' " + + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + + @Test + void addParenthesisToWhere() { + String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + String replaceSql = SqlParserAddHelper.addParenthesisToWhere(sql); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01') " + + "ORDER BY 播放量 DESC LIMIT 11", + replaceSql); + } + +} diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java new file mode 100644 index 000000000..baa6e7588 --- /dev/null +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java @@ -0,0 +1,50 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.HashSet; +import java.util.Set; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +/** + * SqlParser Remove Helper Test + */ +class SqlParserRemoveHelperTest { + + @Test + void removeWhereCondition() { + String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Set removeFieldNames = new HashSet<>(); + removeFieldNames.add("歌曲名"); + + String replaceSql = SqlParserRemoveHelper.removeWhereCondition(sql, removeFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "ORDER BY 播放量 DESC LIMIT 11", + replaceSql); + + sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + replaceSql = SqlParserRemoveHelper.removeWhereCondition(sql, removeFieldNames); + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND " + + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", + replaceSql); + + sql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋')) and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; + replaceSql = SqlParserRemoveHelper.removeWhereCondition(sql, removeFieldNames); + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 1 IN (1) AND 1 IN (1)) AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", + replaceSql); + } + +} diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java new file mode 100644 index 000000000..7bdb9da5f --- /dev/null +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java @@ -0,0 +1,384 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +/** + * SqlParserReplaceHelperTest + */ +class SqlParserReplaceHelperTest { + + @Test + void replaceValue() { + + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '杰伦' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Map> filedNameToValueMap = new HashMap<>(); + + Map valueMap = new HashMap<>(); + valueMap.put("杰伦", "周杰伦"); + filedNameToValueMap.put("歌手名", valueMap); + + replaceSql = SqlParserReplaceHelper.replaceValue(replaceSql, filedNameToValueMap); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND " + + "歌手名 = '周杰伦' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Map> filedNameToValueMap2 = new HashMap<>(); + + Map valueMap2 = new HashMap<>(); + valueMap2.put("周杰伦", "周杰伦"); + valueMap2.put("林俊杰", "林俊杰"); + valueMap2.put("陈奕迅", "陈奕迅"); + filedNameToValueMap2.put("歌手名", valueMap2); + + replaceSql = SqlParserReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " + + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND " + + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " + + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 歌曲发布时 = '2023-08-01') " + + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01' and 播放量 < (" + + "select min(播放量) from 歌曲库 where 语种 = '英文' " + + ") ) and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND " + + "歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 歌曲发布时 = '2023-08-01' " + + "AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) AND 数据日期 = '2023-08-09' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + } + + + @Test + void replaceFieldNameByValue() { + + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Map> fieldValueToFieldNames = new HashMap<>(); + fieldValueToFieldNames.put("邓紫棋", Collections.singleton("歌手名")); + + replaceSql = SqlParserReplaceHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND " + + "歌手名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 like '%邓紫棋%' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserReplaceHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌手名 LIKE '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " + + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + Set fieldNames = new HashSet<>(); + fieldNames.add("歌手名"); + fieldNames.add("歌曲名"); + fieldNames.add("专辑名"); + + fieldValueToFieldNames.put("林俊杰", fieldNames); + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '林俊杰' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserReplaceHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌手名 = '林俊杰' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " + + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '林俊杰' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09'" + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserReplaceHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND " + + "歌手名 = '林俊杰' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + } + + @Test + void replaceFields() { + + Map fieldToBizName = initParams(); + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' " + + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND " + + "song_publis_date = '2023-08-01' AND publish_date >= '2023-08-08' " + + "ORDER BY play_count DESC LIMIT 11", replaceSql); + + replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " + + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " + + "group by MONTH(数据日期) order by sum(访问次数) desc limit 1"; + + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT MONTH(sys_imp_date), sum(pv) FROM 内容库产品 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-03-03' " + + "GROUP BY MONTH(sys_imp_date) ORDER BY sum(pv) DESC LIMIT 1", replaceSql); + + replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " + + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " + + "group by MONTH(数据日期) HAVING sum(访问次数) > 1000"; + + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT MONTH(sys_imp_date), sum(pv) FROM 内容库产品 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-03-03' GROUP BY MONTH(sys_imp_date) HAVING sum(pv) > 1000", + replaceSql); + + replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) " + + "in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)"; + + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 " + + "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' " + + "GROUP BY YEAR(publish_date)", + replaceSql); + + replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 " + + "where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' " + + "group by 发行日期"; + + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 " + + "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14'" + + " GROUP BY publish_date", + replaceSql); + + replaceSql = SqlParserReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 " + + "and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30", + fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-11' " + + "AND play_count > 1000000 AND sys_imp_date <= '2023-08-11' AND " + + "publish_date >= '2022-08-11' AND sys_imp_date >= '2023-07-12'", replaceSql); + + replaceSql = SqlParserReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", + fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' " + + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' " + + "AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11", replaceSql); + + replaceSql = SqlParserReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE 1 = 1 AND singer_name = '邓紫棋'" + + " AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-08-09' " + + "AND publish_date >= '2023-01-01' ORDER BY play_count DESC LIMIT 11", replaceSql); + + replaceSql = SqlParserReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' " + + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' " + + "AND publish_date >= '2023-02-09' ORDER BY play_count DESC LIMIT 11", replaceSql); + + replaceSql = SqlParserReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date >= '2023-08-09' " + + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' " + + "AND publish_date <= '2023-02-09' ORDER BY play_count DESC LIMIT 11", replaceSql); + + replaceSql = SqlParserReplaceHelper.replaceFields( + "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice'" + + " and 发布日期 ='11' order by 访问次数 desc limit 1", fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT department, user_id FROM 超音数 WHERE sys_imp_date = '2023-08-08'" + + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1", replaceSql); + + replaceSql = SqlParserReplaceHelper.replaceTable(replaceSql, "s2"); + + replaceSql = SqlParserAddHelper.addFieldsToSelect(replaceSql, Collections.singletonList("field_a")); + + replaceSql = SqlParserReplaceHelper.replaceFields( + "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 ='alice' and 发布日期 ='11' group by 部门 limit 1", fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08'" + + " AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1", replaceSql); + + replaceSql = "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' " + + "and 数据日期 <= '2023-08-06' and 部门 = 'hr'"; + replaceSql = SqlParserReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' " + + "AND sys_imp_date <= '2023-08-06' AND department = 'hr'", replaceSql); + } + + + @Test + void replaceTable() { + + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + String replaceSql = SqlParserReplaceHelper.replaceTable(sql, "s2"); + + Assert.assertEquals( + "SELECT 部门, sum(访问次数) FROM s2 WHERE 数据日期 = '2023-08-08' " + + "AND 用户 = alice AND 发布日期 = '11' GROUP BY 部门 LIMIT 1", replaceSql); + + sql = "SELECT * FROM CSpider音乐 WHERE (评分 < (SELECT min(评分) " + + "FROM CSpider音乐 WHERE 语种 = '英文')) AND 数据日期 = '2023-10-11'"; + replaceSql = SqlParserReplaceHelper.replaceTable(sql, "cspider"); + + Assert.assertEquals( + "SELECT * FROM cspider WHERE (评分 < (SELECT min(评分) FROM " + + "cspider WHERE 语种 = '英文')) AND 数据日期 = '2023-10-11'", replaceSql); + } + + @Test + void replaceFunctionName() { + + String sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; + Map functionMap = new HashMap<>(); + functionMap.put("MONTH".toLowerCase(), "toMonth"); + String replaceSql = SqlParserReplaceHelper.replaceFunction(sql, functionMap); + + Assert.assertEquals( + "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" + + " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)", + replaceSql); + + sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; + replaceSql = SqlParserReplaceHelper.replaceFunction(sql, functionMap); + + Assert.assertEquals( + "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" + + " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)", + replaceSql); + + sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + + " (datediff('month', 数据日期, '2023-09-02') <= 6) and 数据日期 = '2023-10-10' group by MONTH(数据日期)"; + replaceSql = SqlParserReplaceHelper.replaceFunction(sql, functionMap); + + Assert.assertEquals( + "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" + + " (datediff('month', 数据日期, '2023-09-02') <= 6) AND " + + "数据日期 = '2023-10-10' GROUP BY toMonth(数据日期)", + replaceSql); + } + + @Test + void replaceAlias() { + String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; + String replaceSql = SqlParserReplaceHelper.replaceAlias(sql); + System.out.println(replaceSql); + Assert.assertEquals( + "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", + replaceSql); + + sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 总访问次数 desc limit 10"; + replaceSql = SqlParserReplaceHelper.replaceAlias(sql); + System.out.println(replaceSql); + Assert.assertEquals( + "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " + + "GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", + replaceSql); + + } + + private Map initParams() { + Map fieldToBizName = new HashMap<>(); + fieldToBizName.put("部门", "department"); + fieldToBizName.put("用户", "user_id"); + fieldToBizName.put("数据日期", "sys_imp_date"); + fieldToBizName.put("发布日期", "publish_date"); + fieldToBizName.put("访问次数", "pv"); + fieldToBizName.put("歌曲名", "song_name"); + fieldToBizName.put("歌手名", "singer_name"); + fieldToBizName.put("播放", "play_count"); + fieldToBizName.put("歌曲发布时间", "song_publis_date"); + fieldToBizName.put("歌曲发布年份", "song_publis_year"); + fieldToBizName.put("访问次数", "pv"); + return fieldToBizName; + } +} diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelperTest.java new file mode 100644 index 000000000..0887e7921 --- /dev/null +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectFunctionHelperTest.java @@ -0,0 +1,46 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import net.sf.jsqlparser.JSQLParserException; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +/** + * SqlParserSelectHelper Test + */ +class SqlParserSelectFunctionHelperTest { + + @Test + void hasAggregateFunction() throws JSQLParserException { + + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + boolean hasAggregateFunction = SqlParserSelectFunctionHelper.hasAggregateFunction(sql); + + Assert.assertEquals(hasAggregateFunction, true); + sql = "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + hasAggregateFunction = SqlParserSelectFunctionHelper.hasAggregateFunction(sql); + Assert.assertEquals(hasAggregateFunction, true); + + sql = "SELECT count(1) FROM s2 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice'" + + " AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; + hasAggregateFunction = SqlParserSelectFunctionHelper.hasAggregateFunction(sql); + Assert.assertEquals(hasAggregateFunction, true); + + sql = "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; + hasAggregateFunction = SqlParserSelectFunctionHelper.hasAggregateFunction(sql); + Assert.assertEquals(hasAggregateFunction, false); + + sql = "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" + + " AND user_id = 'alice' AND publish_date = '11'"; + hasAggregateFunction = SqlParserSelectFunctionHelper.hasAggregateFunction(sql); + Assert.assertEquals(hasAggregateFunction, false); + + sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; + hasAggregateFunction = SqlParserSelectFunctionHelper.hasAggregateFunction(sql); + Assert.assertEquals(hasAggregateFunction, true); + } + +} diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index e5c24e4fb..80def539e 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -1,9 +1,6 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -13,7 +10,6 @@ import org.junit.jupiter.api.Test; */ class SqlParserSelectHelperTest { - @Test void getWhereFilterExpression() { @@ -189,56 +185,6 @@ class SqlParserSelectHelperTest { Assert.assertEquals(selectFields.contains("pv"), true); } - @Test - void hasAggregateFunction() throws JSQLParserException { - - String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; - boolean hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); - - Assert.assertEquals(hasAggregateFunction, true); - sql = "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; - hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); - Assert.assertEquals(hasAggregateFunction, true); - - sql = "SELECT count(1) FROM s2 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice'" - + " AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; - hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); - Assert.assertEquals(hasAggregateFunction, true); - - sql = "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " - + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; - hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); - Assert.assertEquals(hasAggregateFunction, false); - - sql = "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11'"; - hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); - Assert.assertEquals(hasAggregateFunction, false); - - sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; - hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); - Assert.assertEquals(hasAggregateFunction, true); - } - - private Map initParams() { - Map fieldToBizName = new HashMap<>(); - fieldToBizName.put("部门", "department"); - fieldToBizName.put("用户", "user_id"); - fieldToBizName.put("数据日期", "sys_imp_date"); - fieldToBizName.put("发布日期", "publish_date"); - fieldToBizName.put("访问次数", "pv"); - fieldToBizName.put("歌曲名", "song_name"); - fieldToBizName.put("歌手名", "singer_name"); - fieldToBizName.put("播放", "play_count"); - fieldToBizName.put("歌曲发布时间", "song_publis_date"); - fieldToBizName.put("歌曲发布年份", "song_publis_year"); - fieldToBizName.put("转3.0前后30天结算份额衰减", "fdafdfdsa_fdas"); - return fieldToBizName; - } - @Test void getGroupByFields() { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java deleted file mode 100644 index ee5bd7e74..000000000 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ /dev/null @@ -1,662 +0,0 @@ -package com.tencent.supersonic.common.util.jsqlparser; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import net.sf.jsqlparser.JSQLParserException; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; -import org.junit.Assert; -import org.junit.jupiter.api.Test; - -/** - * SqlParserUpdateHelper Test - */ -class SqlParserUpdateHelperTest { - - @Test - void replaceValue() { - - String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '杰伦' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - Map> filedNameToValueMap = new HashMap<>(); - - Map valueMap = new HashMap<>(); - valueMap.put("杰伦", "周杰伦"); - filedNameToValueMap.put("歌手名", valueMap); - - replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND " - + "歌手名 = '周杰伦' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " - + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); - - replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - Map> filedNameToValueMap2 = new HashMap<>(); - - Map valueMap2 = new HashMap<>(); - valueMap2.put("周杰伦", "周杰伦"); - valueMap2.put("林俊杰", "林俊杰"); - valueMap2.put("陈奕迅", "陈奕迅"); - filedNameToValueMap2.put("歌手名", valueMap2); - - replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " - + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND " - + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - - replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09' " - + " order by 播放量 desc limit 11"; - - replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " - + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 歌曲发布时 = '2023-08-01') " - + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - - } - - - @Test - void replaceFieldNameByValue() { - - String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - Map> fieldValueToFieldNames = new HashMap<>(); - fieldValueToFieldNames.put("邓紫棋", Collections.singleton("歌手名")); - - replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND " - + "歌手名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " - + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); - - replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 like '%邓紫棋%' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 歌手名 LIKE '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " - + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - - Set fieldNames = new HashSet<>(); - fieldNames.add("歌手名"); - fieldNames.add("歌曲名"); - fieldNames.add("专辑名"); - - fieldValueToFieldNames.put("林俊杰", fieldNames); - replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '林俊杰' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 歌手名 = '林俊杰' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " - + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - - replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '林俊杰' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09'" - + " order by 播放量 desc limit 11"; - - replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND " - + "歌手名 = '林俊杰' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' " - + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); - - } - - @Test - void replaceFields() { - - Map fieldToBizName = initParams(); - String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - replaceSql = SqlParserUpdateHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' " - + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND " - + "song_publis_date = '2023-08-01' AND publish_date >= '2023-08-08' " - + "ORDER BY play_count DESC LIMIT 11", replaceSql); - - replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " - + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " - + "group by MONTH(数据日期) order by sum(访问次数) desc limit 1"; - - replaceSql = SqlParserUpdateHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT MONTH(sys_imp_date), sum(pv) FROM 内容库产品 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-03-03' " - + "GROUP BY MONTH(sys_imp_date) ORDER BY sum(pv) DESC LIMIT 1", replaceSql); - - replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " - + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " - + "group by MONTH(数据日期) HAVING sum(访问次数) > 1000"; - - replaceSql = SqlParserUpdateHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT MONTH(sys_imp_date), sum(pv) FROM 内容库产品 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-03-03' GROUP BY MONTH(sys_imp_date) HAVING sum(pv) > 1000", - replaceSql); - - replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) " - + "in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)"; - - replaceSql = SqlParserUpdateHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' " - + "GROUP BY YEAR(publish_date)", - replaceSql); - - replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 " - + "where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' " - + "group by 发行日期"; - - replaceSql = SqlParserUpdateHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14'" - + " GROUP BY publish_date", - replaceSql); - - replaceSql = SqlParserUpdateHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 " - + "and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30", - fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-11' " - + "AND play_count > 1000000 AND sys_imp_date <= '2023-08-11' AND " - + "publish_date >= '2022-08-11' AND sys_imp_date >= '2023-07-12'", replaceSql); - - replaceSql = SqlParserUpdateHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' " - + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' " - + "AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11", replaceSql); - - replaceSql = SqlParserUpdateHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE 1 = 1 AND singer_name = '邓紫棋'" - + " AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-08-09' " - + "AND publish_date >= '2023-01-01' ORDER BY play_count DESC LIMIT 11", replaceSql); - - replaceSql = SqlParserUpdateHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' " - + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' " - + "AND publish_date >= '2023-02-09' ORDER BY play_count DESC LIMIT 11", replaceSql); - - replaceSql = SqlParserUpdateHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE publish_date >= '2023-08-09' " - + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' " - + "AND publish_date <= '2023-02-09' ORDER BY play_count DESC LIMIT 11", replaceSql); - - replaceSql = SqlParserUpdateHelper.replaceFields( - "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice'" - + " and 发布日期 ='11' order by 访问次数 desc limit 1", fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT department, user_id FROM 超音数 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1", replaceSql); - - replaceSql = SqlParserUpdateHelper.replaceTable(replaceSql, "s2"); - - replaceSql = SqlParserUpdateHelper.addFieldsToSelect(replaceSql, Collections.singletonList("field_a")); - - replaceSql = SqlParserUpdateHelper.replaceFields( - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 ='alice' and 发布日期 ='11' group by 部门 limit 1", fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1", replaceSql); - - replaceSql = "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' " - + "and 数据日期 <= '2023-08-06' and 部门 = 'hr'"; - replaceSql = SqlParserUpdateHelper.replaceFields(replaceSql, fieldToBizName); - replaceSql = SqlParserUpdateHelper.replaceFunction(replaceSql); - - Assert.assertEquals( - "SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' " - + "AND sys_imp_date <= '2023-08-06' AND department = 'hr'", replaceSql); - } - - - @Test - void replaceTable() { - - String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; - String replaceSql = SqlParserUpdateHelper.replaceTable(sql, "s2"); - - Assert.assertEquals( - "SELECT 部门, sum(访问次数) FROM s2 WHERE 数据日期 = '2023-08-08' " - + "AND 用户 = alice AND 发布日期 = '11' GROUP BY 部门 LIMIT 1", replaceSql); - } - - @Test - void addWhere() throws JSQLParserException { - - String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; - sql = SqlParserUpdateHelper.addWhere(sql, "column_a", 123444555); - List selectFields = SqlParserSelectHelper.getAllFields(sql); - - Assert.assertEquals(selectFields.contains("column_a"), true); - - sql = SqlParserUpdateHelper.addWhere(sql, "column_b", "123456666"); - selectFields = SqlParserSelectHelper.getAllFields(sql); - - Assert.assertEquals(selectFields.contains("column_b"), true); - - Expression expression = CCJSqlParserUtil.parseCondExpression(" ( column_c = 111 or column_d = 1111)"); - - sql = SqlParserUpdateHelper.addWhere( - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1", - expression); - - Assert.assertEquals(sql.contains("column_c = 111"), true); - - sql = "select 部门,sum (访问次数) from 超音数 where 用户 = alice or 发布日期 ='2023-07-03' group by 部门 limit 1"; - sql = SqlParserUpdateHelper.addParenthesisToWhere(sql); - sql = SqlParserUpdateHelper.addWhere(sql, "数据日期", "2023-08-08"); - Assert.assertEquals(sql, "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " - + "(用户 = alice OR 发布日期 = '2023-07-03') AND 数据日期 = '2023-08-08' GROUP BY 部门 LIMIT 1"); - - } - - @Test - void replaceFunctionName() { - - String sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" - + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; - Map functionMap = new HashMap<>(); - functionMap.put("MONTH".toLowerCase(), "toMonth"); - String replaceSql = SqlParserUpdateHelper.replaceFunction(sql, functionMap); - - Assert.assertEquals( - "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" - + " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)", - replaceSql); - - sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" - + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; - replaceSql = SqlParserUpdateHelper.replaceFunction(sql, functionMap); - - Assert.assertEquals( - "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" - + " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)", - replaceSql); - - sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" - + " (datediff('month', 数据日期, '2023-09-02') <= 6) and 数据日期 = '2023-10-10' group by MONTH(数据日期)"; - replaceSql = SqlParserUpdateHelper.replaceFunction(sql, functionMap); - - Assert.assertEquals( - "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" - + " (datediff('month', 数据日期, '2023-09-02') <= 6) AND " - + "数据日期 = '2023-10-10' GROUP BY toMonth(数据日期)", - replaceSql); - } - - @Test - void replaceAlias() { - String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " - + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; - String replaceSql = SqlParserUpdateHelper.replaceAlias(sql); - System.out.println(replaceSql); - Assert.assertEquals( - "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " - + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", - replaceSql); - - sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " - + "group by 部门 order by 总访问次数 desc limit 10"; - replaceSql = SqlParserUpdateHelper.replaceAlias(sql); - System.out.println(replaceSql); - Assert.assertEquals( - "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " - + "GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", - replaceSql); - - } - - @Test - void addFunctionToSelect() { - String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); - - String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); - System.out.println(replaceSql); - Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", - replaceSql); - - sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); - - replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); - System.out.println(replaceSql); - Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", - replaceSql); - - sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " - + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); - - replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); - System.out.println(replaceSql); - Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " - + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", - replaceSql); - - } - - @Test - void addAggregateToField() { - String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); - - String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); - System.out.println(replaceSql); - Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", - replaceSql); - - sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); - - replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); - System.out.println(replaceSql); - Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", - replaceSql); - - sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " - + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; - havingExpression = SqlParserSelectHelper.getHavingExpression(sql); - - replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); - System.out.println(replaceSql); - Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') " - + "AND sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", - replaceSql); - } - - - @Test - void addAggregateToMetricField() { - String sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' order by pv desc limit 10"; - - Map filedNameToAggregate = new HashMap<>(); - filedNameToAggregate.put("pv", "sum"); - - Set groupByFields = new HashSet<>(); - groupByFields.add("department"); - - String replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " - + "order by pv desc limit 10"; - replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, pv from t_1 where pv >1 order by pv desc limit 10"; - replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, pv from t_1 where sum(pv) >1 order by pv desc limit 10"; - replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 " - + "GROUP BY department order by pv desc limit 10"; - replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " - + "GROUP BY department order by pv desc limit 10"; - replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 and department = 'HR' " - + "GROUP BY department order by pv desc limit 10"; - replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND sum(pv) > 1 " - + "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, pv from t_1 where (pv >1 and department = 'HR') " - + " and sys_imp_date = '2023-09-11' GROUP BY department order by pv desc limit 10"; - replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); - replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE (sum(pv) > 1 AND department = 'HR') AND " - + "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - } - - @Test - void addGroupBy() { - String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " - + "order by sum(pv) desc limit 10"; - - Set groupByFields = new HashSet<>(); - groupByFields.add("department"); - - String replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, sum(pv) from t_1 where (department = 'HR') and sys_imp_date = '2023-09-11' " - + "order by sum(pv) desc limit 10"; - - replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE (department = 'HR') AND sys_imp_date " - + "= '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - } - - @Test - void addHaving() { - String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and " - + "sum(pv) > 2000 group by department order by sum(pv) desc limit 10"; - List groupByFields = new ArrayList<>(); - groupByFields.add("department"); - - Set fieldNames = new HashSet<>(); - fieldNames.add("pv"); - - String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND 2 > 1 " - + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - - sql = "select department, sum(pv) from t_1 where (sum(pv) > 2000) and sys_imp_date = '2023-09-11' " - + "group by department order by sum(pv) desc limit 10"; - - replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames); - - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE (2 > 1) AND sys_imp_date = '2023-09-11' " - + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); - } - - @Test - void removeWhereCondition() { - String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - Set removeFieldNames = new HashSet<>(); - removeFieldNames.add("歌曲名"); - - String replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " - + "ORDER BY 播放量 DESC LIMIT 11", - replaceSql); - - sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND " - + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", - replaceSql); - - sql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋')) and 数据日期 = '2023-08-09' " - + " order by 播放量 desc limit 11"; - replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 1 IN (1) AND 1 IN (1)) AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", - replaceSql); - } - - - @Test - void addParenthesisToWhere() { - String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; - - String replaceSql = SqlParserUpdateHelper.addParenthesisToWhere(sql); - - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01') " - + "ORDER BY 播放量 DESC LIMIT 11", - replaceSql); - } - - private Map initParams() { - Map fieldToBizName = new HashMap<>(); - fieldToBizName.put("部门", "department"); - fieldToBizName.put("用户", "user_id"); - fieldToBizName.put("数据日期", "sys_imp_date"); - fieldToBizName.put("发布日期", "publish_date"); - fieldToBizName.put("访问次数", "pv"); - fieldToBizName.put("歌曲名", "song_name"); - fieldToBizName.put("歌手名", "singer_name"); - fieldToBizName.put("播放", "play_count"); - fieldToBizName.put("歌曲发布时间", "song_publis_date"); - fieldToBizName.put("歌曲发布年份", "song_publis_year"); - fieldToBizName.put("访问次数", "pv"); - return fieldToBizName; - } -} diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/ClickHouseAdaptor.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/ClickHouseAdaptor.java index f4e8b4b9a..5c35de462 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/ClickHouseAdaptor.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/ClickHouseAdaptor.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.semantic.model.domain.adaptor.engineadapter; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.Constants; import java.util.HashMap; @@ -49,7 +49,7 @@ public class ClickHouseAdaptor extends EngineAdaptor { functionMap.put("MONTH".toLowerCase(), "toMonth"); functionMap.put("DAY".toLowerCase(), "toDayOfMonth"); functionMap.put("YEAR".toLowerCase(), "toYear"); - return SqlParserUpdateHelper.replaceFunction(sql, functionMap); + return SqlParserReplaceHelper.replaceFunction(sql, functionMap); } @Override diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java index 2c2c6592a..aa4e297a7 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java @@ -1,8 +1,8 @@ package com.tencent.supersonic.semantic.query.parser.convert; import com.tencent.supersonic.common.util.DateUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import com.tencent.supersonic.semantic.api.model.request.SqlExecuteReq; @@ -108,7 +108,7 @@ public class QueryReqConverter { Map fieldNameToBizNameMap = getFieldNameToBizNameMap(modelSchemaResp); String sql = databaseReq.getSql(); log.info("convert name to bizName before:{}", sql); - String replaceFields = SqlParserUpdateHelper.replaceFields(sql, fieldNameToBizNameMap, false); + String replaceFields = SqlParserReplaceHelper.replaceFields(sql, fieldNameToBizNameMap, false); log.info("convert name to bizName after:{}", replaceFields); databaseReq.setSql(replaceFields); } @@ -159,7 +159,7 @@ public class QueryReqConverter { } public void correctTableName(QueryDslReq databaseReq) { - String sql = SqlParserUpdateHelper.replaceTable(databaseReq.getSql(), TABLE_PREFIX + databaseReq.getModelId()); + String sql = SqlParserReplaceHelper.replaceTable(databaseReq.getSql(), TABLE_PREFIX + databaseReq.getModelId()); databaseReq.setSql(sql); } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java index 0fb93e8ca..eaa70367e 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java @@ -1,11 +1,13 @@ package com.tencent.supersonic.semantic.query.utils; +import static com.tencent.supersonic.common.pojo.Constants.MINUS; + import com.google.common.base.Strings; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; @@ -13,6 +15,13 @@ import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.model.domain.DimensionService; import com.tencent.supersonic.semantic.model.domain.ModelService; import com.tencent.supersonic.semantic.query.service.AuthCommonService; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.StringJoiner; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; @@ -28,17 +37,6 @@ import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import java.util.StringJoiner; -import java.util.Objects; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; -import java.util.HashSet; - -import java.util.stream.Collectors; - -import static com.tencent.supersonic.common.pojo.Constants.MINUS; - @Component @Aspect @Order(1) @@ -147,7 +145,7 @@ public class DslDataAspect { try { Expression expression = CCJSqlParserUtil.parseCondExpression(" ( " + joiner.toString() + " ) "); if (StringUtils.isNotEmpty(joiner.toString())) { - String sql = SqlParserUpdateHelper.addWhere(queryDslReq.getSql(), expression); + String sql = SqlParserAddHelper.addWhere(queryDslReq.getSql(), expression); log.info("before doRowPermission, queryDslReq:{}", queryDslReq.getSql()); queryDslReq.setSql(sql); log.info("after doRowPermission, queryDslReq:{}", queryDslReq.getSql());