mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) If there are function operator fields, precise replacement must be performed. (#1554)
This commit is contained in:
@@ -158,10 +158,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
AzureModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER
|
||||
),
|
||||
Lists.newArrayList(AzureModelFactory.PROVIDER, QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, "llama_2_70b"
|
||||
@@ -172,9 +169,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(
|
||||
QianfanModelFactory.PROVIDER, DEMO
|
||||
)
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,12 +101,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
|
||||
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.MILVUS.name()
|
||||
),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.MILVUS.name(), "384"
|
||||
)
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -29,8 +29,7 @@ public enum AggregateEnum {
|
||||
}
|
||||
|
||||
public static Map<String, String> getAggregateEnum() {
|
||||
Map<String, String> aggregateMap = Arrays.stream(AggregateEnum.values())
|
||||
return Arrays.stream(AggregateEnum.values())
|
||||
.collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
|
||||
return aggregateMap;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,36 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||
private Map<String, String> fieldNameMap;
|
||||
private boolean exactReplace;
|
||||
private ThreadLocal<Boolean> exactReplace = ThreadLocal.withInitial(() -> false);
|
||||
|
||||
public FieldReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||
this.fieldNameMap = fieldNameMap;
|
||||
this.exactReplace = exactReplace;
|
||||
this.exactReplace.set(exactReplace);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(Column column) {
|
||||
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace);
|
||||
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace.get());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(Function function) {
|
||||
boolean originalExactReplace = exactReplace.get();
|
||||
exactReplace.set(true);
|
||||
try {
|
||||
super.visit(function);
|
||||
} finally {
|
||||
exactReplace.set(originalExactReplace);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -126,8 +126,6 @@ public class SqlReplaceHelper {
|
||||
if (!(selectStatement instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
//List<PlainSelect> plainSelectList = new ArrayList<>();
|
||||
//plainSelectList.add((PlainSelect) selectStatement);
|
||||
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(selectStatement);
|
||||
for (PlainSelect plainSelect : plainSelects) {
|
||||
Expression where = plainSelect.getWhere();
|
||||
@@ -186,18 +184,14 @@ public class SqlReplaceHelper {
|
||||
public static String replaceFields(String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
|
||||
//plainSelectList.add(selectStatement.getPlainSelect());
|
||||
if (selectStatement instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
||||
plainSelectList.add(plainSelect);
|
||||
getFromSelect(plainSelect.getFromItem(), plainSelectList);
|
||||
//plainSelectList.add((PlainSelect) selectStatement);
|
||||
} else if (selectStatement instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) selectStatement;
|
||||
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
||||
setOperationList.getSelects().forEach(subSelectBody -> {
|
||||
//PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||
//plainSelectList.add(subPlainSelect);
|
||||
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||
plainSelectList.add(subPlainSelect);
|
||||
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
|
||||
|
||||
@@ -252,8 +252,8 @@ class SqlReplaceHelperTest {
|
||||
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 "
|
||||
+ "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' "
|
||||
"SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 "
|
||||
+ "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14' "
|
||||
+ "GROUP BY YEAR(publish_date)",
|
||||
replaceSql);
|
||||
|
||||
@@ -265,8 +265,8 @@ class SqlReplaceHelperTest {
|
||||
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 "
|
||||
+ "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14'"
|
||||
"SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 "
|
||||
+ "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14'"
|
||||
+ " GROUP BY publish_date",
|
||||
replaceSql);
|
||||
|
||||
@@ -360,9 +360,9 @@ class SqlReplaceHelperTest {
|
||||
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT song_name, sum(user_id) FROM CSpider WHERE (1 < 2) AND "
|
||||
"SELECT song_name, sum(评分) FROM CSpider WHERE (1 < 2) AND "
|
||||
+ "sys_imp_date = '2023-10-15' GROUP BY song_name HAVING "
|
||||
+ "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql);
|
||||
+ "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", replaceSql);
|
||||
|
||||
replaceSql = "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')"
|
||||
+ " FROM CSpider WHERE 数据日期 = '2023-10-15' "
|
||||
@@ -371,9 +371,21 @@ class SqlReplaceHelperTest {
|
||||
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT sum(user_id) / (SELECT sum(user_id) FROM CSpider WHERE sys_imp_date = '2023-10-15') "
|
||||
"SELECT sum(评分) / (SELECT sum(评分) FROM CSpider WHERE sys_imp_date = '2023-10-15') "
|
||||
+ "FROM CSpider WHERE sys_imp_date = '2023-10-15' GROUP BY song_name HAVING "
|
||||
+ "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql);
|
||||
+ "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testReplaceFunctionField() {
|
||||
Map<String, String> fieldToBizName = initParams();
|
||||
String replaceSql = "SELECT TIMESTAMPDIFF (MONTH,歌曲发布时间,CURDATE()) AS 发布月数 FROM 歌曲库 WHERE 歌手名 = '邓紫棋' ";
|
||||
|
||||
replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName);
|
||||
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
|
||||
Assert.assertEquals(
|
||||
"SELECT TIMESTAMPDIFF(MONTH, song_publis_date, CURDATE()) AS 发布月数 "
|
||||
+ "FROM 歌曲库 WHERE singer_name = '邓紫棋'", replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
Reference in New Issue
Block a user