(improvement)(chat) If there are function operator fields, precise replacement must be performed. (#1554)

This commit is contained in:
lexluo09
2024-08-12 12:45:24 +08:00
committed by GitHub
parent 1ff4a71a41
commit 0c70df12ca
6 changed files with 42 additions and 34 deletions

View File

@@ -158,10 +158,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(
AzureModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER
),
Lists.newArrayList(AzureModelFactory.PROVIDER, QianfanModelFactory.PROVIDER),
ImmutableMap.of(
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, "llama_2_70b"
@@ -172,9 +169,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(
QianfanModelFactory.PROVIDER, DEMO
)
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
);
}
}

View File

@@ -101,12 +101,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
private static List<Parameter.Dependency> getDimensionDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "384"
)
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384")
);
}
}

View File

@@ -29,8 +29,7 @@ public enum AggregateEnum {
}
public static Map<String, String> getAggregateEnum() {
Map<String, String> aggregateMap = Arrays.stream(AggregateEnum.values())
return Arrays.stream(AggregateEnum.values())
.collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
return aggregateMap;
}
}

View File

@@ -1,24 +1,36 @@
package com.tencent.supersonic.common.jsqlparser;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.schema.Column;
import java.util.Map;
@Slf4j
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldNameMap;
private boolean exactReplace;
private ThreadLocal<Boolean> exactReplace = ThreadLocal.withInitial(() -> false);
public FieldReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldNameMap = fieldNameMap;
this.exactReplace = exactReplace;
this.exactReplace.set(exactReplace);
}
@Override
public void visit(Column column) {
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace);
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace.get());
}
@Override
public void visit(Function function) {
boolean originalExactReplace = exactReplace.get();
exactReplace.set(true);
try {
super.visit(function);
} finally {
exactReplace.set(originalExactReplace);
}
}
}

View File

@@ -126,8 +126,6 @@ public class SqlReplaceHelper {
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
//List<PlainSelect> plainSelectList = new ArrayList<>();
//plainSelectList.add((PlainSelect) selectStatement);
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(selectStatement);
for (PlainSelect plainSelect : plainSelects) {
Expression where = plainSelect.getWhere();
@@ -186,18 +184,14 @@ public class SqlReplaceHelper {
public static String replaceFields(String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
//plainSelectList.add(selectStatement.getPlainSelect());
if (selectStatement instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectStatement;
plainSelectList.add(plainSelect);
getFromSelect(plainSelect.getFromItem(), plainSelectList);
//plainSelectList.add((PlainSelect) selectStatement);
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList.getSelects().forEach(subSelectBody -> {
//PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
//plainSelectList.add(subPlainSelect);
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);

View File

@@ -252,8 +252,8 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 "
+ "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' "
"SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 "
+ "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14' "
+ "GROUP BY YEAR(publish_date)",
replaceSql);
@@ -265,8 +265,8 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 "
+ "WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14'"
"SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 "
+ "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14'"
+ " GROUP BY publish_date",
replaceSql);
@@ -360,9 +360,9 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT song_name, sum(user_id) FROM CSpider WHERE (1 < 2) AND "
"SELECT song_name, sum(评分) FROM CSpider WHERE (1 < 2) AND "
+ "sys_imp_date = '2023-10-15' GROUP BY song_name HAVING "
+ "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql);
+ "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", replaceSql);
replaceSql = "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')"
+ " FROM CSpider WHERE 数据日期 = '2023-10-15' "
@@ -371,9 +371,21 @@ class SqlReplaceHelperTest {
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT sum(user_id) / (SELECT sum(user_id) FROM CSpider WHERE sys_imp_date = '2023-10-15') "
"SELECT sum(评分) / (SELECT sum(评分) FROM CSpider WHERE sys_imp_date = '2023-10-15') "
+ "FROM CSpider WHERE sys_imp_date = '2023-10-15' GROUP BY song_name HAVING "
+ "sum(user_id) < (SELECT min(user_id) FROM CSpider WHERE user_id = '英文')", replaceSql);
+ "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", replaceSql);
}
@Test
void testReplaceFunctionField() {
Map<String, String> fieldToBizName = initParams();
String replaceSql = "SELECT TIMESTAMPDIFF (MONTH,歌曲发布时间,CURDATE()) AS 发布月数 FROM 歌曲库 WHERE 歌手名 = '邓紫棋' ";
replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName);
replaceSql = SqlReplaceHelper.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT TIMESTAMPDIFF(MONTH, song_publis_date, CURDATE()) AS 发布月数 "
+ "FROM 歌曲库 WHERE singer_name = '邓紫棋'", replaceSql);
}
@Test