diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledNameReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledNameReplaceVisitor.java index fed24f0fd..cf3c958d5 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledNameReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledNameReplaceVisitor.java @@ -3,15 +3,18 @@ package com.tencent.supersonic.common.util.jsqlparser; import java.util.Map; import java.util.Objects; import java.util.Set; +import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.LikeExpression; import net.sf.jsqlparser.schema.Column; import org.springframework.util.CollectionUtils; public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter { + public static final String PREFIX = "%"; private Map> fieldValueToFieldNames; public FiledNameReplaceVisitor(Map> fieldValueToFieldNames) { @@ -20,6 +23,15 @@ public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(EqualsTo expr) { + replaceFieldNameByFieldValue(expr); + } + + @Override + public void visit(LikeExpression expr) { + replaceFieldNameByFieldValue(expr); + } + + private void replaceFieldNameByFieldValue(BinaryExpression expr) { Expression leftExpression = expr.getLeftExpression(); Expression rightExpression = expr.getRightExpression(); if (!(rightExpression instanceof StringValue)) { @@ -37,10 +49,26 @@ public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter { Column leftColumnName = (Column) leftExpression; StringValue rightStringValue = (StringValue) rightExpression; + if (expr instanceof LikeExpression) { + String value = getValue(rightStringValue.getValue()); + rightStringValue.setValue(value); + } + Set fieldNames = fieldValueToFieldNames.get(rightStringValue.getValue()); if (!CollectionUtils.isEmpty(fieldNames) && !fieldNames.contains(leftColumnName.getColumnName())) { leftColumnName.setColumnName(fieldNames.stream().findFirst().get()); } } + private String getValue(String value) { + if (value.startsWith(PREFIX)) { + value = value.substring(1); + } + if (value.endsWith(PREFIX)) { + value = value.substring(0, value.length() - 1); + } + return value; + } + + } \ No newline at end of file diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index 9164176df..4a20eb7bb 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -18,6 +18,36 @@ import org.junit.jupiter.api.Test; */ class SqlParserUpdateHelperTest { + + @Test + void replaceFieldNameByValue() { + + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Map> fieldValueToFieldNames = new HashMap<>(); + fieldValueToFieldNames.put("邓紫棋", Collections.singleton("歌手名")); + + replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND " + + "歌手名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 like '%邓紫棋%' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + replaceSql = SqlParserUpdateHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌手名 LIKE '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " + + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + } + @Test void replaceFields() { @@ -340,7 +370,6 @@ class SqlParserUpdateHelperTest { + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 " + "GROUP BY department order by pv desc limit 10"; replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); @@ -351,7 +380,6 @@ class SqlParserUpdateHelperTest { + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " + "GROUP BY department order by pv desc limit 10"; replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate);