diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java index fb8bc568a..2b3f743dd 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java @@ -158,10 +158,7 @@ public class ChatModelParameterConfig extends ParameterConfig { private static List getEndpointDependency() { return getDependency(CHAT_MODEL_PROVIDER.getName(), - Lists.newArrayList( - AzureModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER - ), + Lists.newArrayList(AzureModelFactory.PROVIDER, QianfanModelFactory.PROVIDER), ImmutableMap.of( AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, QianfanModelFactory.PROVIDER, "llama_2_70b" @@ -172,9 +169,7 @@ public class ChatModelParameterConfig extends ParameterConfig { private static List getSecretKeyDependency() { return getDependency(CHAT_MODEL_PROVIDER.getName(), Lists.newArrayList(QianfanModelFactory.PROVIDER), - ImmutableMap.of( - QianfanModelFactory.PROVIDER, DEMO - ) + ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO) ); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java index 18bdb2913..94733e445 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -101,12 +101,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { private static List getDimensionDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList( - EmbeddingStoreType.MILVUS.name() - ), - ImmutableMap.of( - EmbeddingStoreType.MILVUS.name(), "384" - ) + Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), + ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384") ); } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java index 319fbf63f..12a595813 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java @@ -29,8 +29,7 @@ public enum AggregateEnum { } public static Map getAggregateEnum() { - Map aggregateMap = Arrays.stream(AggregateEnum.values()) + return Arrays.stream(AggregateEnum.values()) .collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN)); - return aggregateMap; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java index 5ef1fefab..66ba6a084 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java @@ -1,24 +1,36 @@ package com.tencent.supersonic.common.jsqlparser; -import java.util.Map; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.schema.Column; +import java.util.Map; + @Slf4j public class FieldReplaceVisitor extends ExpressionVisitorAdapter { - ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); private Map fieldNameMap; - private boolean exactReplace; + private ThreadLocal exactReplace = ThreadLocal.withInitial(() -> false); public FieldReplaceVisitor(Map fieldNameMap, boolean exactReplace) { this.fieldNameMap = fieldNameMap; - this.exactReplace = exactReplace; + this.exactReplace.set(exactReplace); } @Override public void visit(Column column) { - parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace); + parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace.get()); + } + + @Override + public void visit(Function function) { + boolean originalExactReplace = exactReplace.get(); + exactReplace.set(true); + try { + super.visit(function); + } finally { + exactReplace.set(originalExactReplace); + } } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index ac2b607ea..108156aae 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -126,8 +126,6 @@ public class SqlReplaceHelper { if (!(selectStatement instanceof PlainSelect)) { return sql; } - //List plainSelectList = new ArrayList<>(); - //plainSelectList.add((PlainSelect) selectStatement); List plainSelects = SqlSelectHelper.getPlainSelect(selectStatement); for (PlainSelect plainSelect : plainSelects) { Expression where = plainSelect.getWhere(); @@ -186,18 +184,14 @@ public class SqlReplaceHelper { public static String replaceFields(String sql, Map fieldNameMap, boolean exactReplace) { Select selectStatement = SqlSelectHelper.getSelect(sql); List plainSelectList = SqlSelectHelper.getWithItem(selectStatement); - //plainSelectList.add(selectStatement.getPlainSelect()); if (selectStatement instanceof PlainSelect) { PlainSelect plainSelect = (PlainSelect) selectStatement; plainSelectList.add(plainSelect); getFromSelect(plainSelect.getFromItem(), plainSelectList); - //plainSelectList.add((PlainSelect) selectStatement); } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { setOperationList.getSelects().forEach(subSelectBody -> { - //PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - //plainSelectList.add(subPlainSelect); PlainSelect subPlainSelect = (PlainSelect) subSelectBody; plainSelectList.add(subPlainSelect); getFromSelect(subPlainSelect.getFromItem(), plainSelectList); diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java index 3b602dfa6..218eb35ee 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java @@ -252,8 +252,8 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( - "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' " + "SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " + + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14' " + "GROUP BY YEAR(publish_date)", replaceSql); @@ -265,8 +265,8 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( - "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14'" + "SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " + + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14'" + " GROUP BY publish_date", replaceSql); @@ -360,9 +360,9 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( - "SELECT song_name, sum(user_id) FROM CSpider WHERE (1 < 2) AND " + "SELECT song_name, sum(评分) FROM CSpider WHERE (1 < 2) AND " + "sys_imp_date = '2023-10-15' GROUP BY song_name HAVING " - + "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql); + + "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", replaceSql); replaceSql = "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')" + " FROM CSpider WHERE 数据日期 = '2023-10-15' " @@ -371,9 +371,21 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( - "SELECT sum(user_id) / (SELECT sum(user_id) FROM CSpider WHERE sys_imp_date = '2023-10-15') " + "SELECT sum(评分) / (SELECT sum(评分) FROM CSpider WHERE sys_imp_date = '2023-10-15') " + "FROM CSpider WHERE sys_imp_date = '2023-10-15' GROUP BY song_name HAVING " - + "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql); + + "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", replaceSql); + } + + @Test + void testReplaceFunctionField() { + Map fieldToBizName = initParams(); + String replaceSql = "SELECT TIMESTAMPDIFF (MONTH,歌曲发布时间,CURDATE()) AS 发布月数 FROM 歌曲库 WHERE 歌手名 = '邓紫棋' "; + + replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); + replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); + Assert.assertEquals( + "SELECT TIMESTAMPDIFF(MONTH, song_publis_date, CURDATE()) AS 发布月数 " + + "FROM 歌曲库 WHERE singer_name = '邓紫棋'", replaceSql); } @Test