diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAliasReplaceNameVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAliasReplaceNameVisitor.java new file mode 100644 index 000000000..91f391bc0 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAliasReplaceNameVisitor.java @@ -0,0 +1,41 @@ +package com.tencent.supersonic.common.jsqlparser; + +import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; + +import net.sf.jsqlparser.expression.Alias; + +public class FieldAliasReplaceNameVisitor extends SelectItemVisitorAdapter { + private Map fieldNameMap; + + private Map aliasToActualExpression = new HashMap<>(); + + public FieldAliasReplaceNameVisitor(Map fieldNameMap) { + this.fieldNameMap = fieldNameMap; + } + + @Override + public void visit(SelectItem selectExpressionItem) { + Alias alias = selectExpressionItem.getAlias(); + if (alias == null) { + return; + } + String aliasName = alias.getName(); + String replaceValue = fieldNameMap.get(aliasName); + if (StringUtils.isBlank(replaceValue)) { + return; + } + + aliasToActualExpression.put(aliasName, replaceValue); + alias.setName(replaceValue); + } + + public Map getAliasToActualExpression() { + return aliasToActualExpression; + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index d5443c328..7d4cd30f9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -449,6 +449,22 @@ public class SqlReplaceHelper { } } + public static String replaceAliasFieldName(String sql, Map fieldNameMap) { + Select selectStatement = SqlSelectHelper.getSelect(sql); + if (!(selectStatement instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectStatement; + FieldAliasReplaceNameVisitor visitor = new FieldAliasReplaceNameVisitor(fieldNameMap); + for (SelectItem selectItem : plainSelect.getSelectItems()) { + selectItem.accept(visitor); + } + Map aliasToActualExpression = visitor.getAliasToActualExpression(); + if (Objects.nonNull(aliasToActualExpression) && !aliasToActualExpression.isEmpty()) { + return replaceFields(selectStatement.toString(), aliasToActualExpression, true); + } + return selectStatement.toString(); + } public static String replaceAlias(String sql) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { diff --git a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java index 89bd8c7c1..4fc14af22 100644 --- a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java @@ -9,6 +9,7 @@ import dev.langchain4j.model.zhipu.ZhipuAiChatModel; import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel; import org.springframework.beans.factory.InitializingBean; import org.springframework.stereotype.Service; +import static java.time.Duration.ofSeconds; @Service public class ZhipuModelFactory implements ModelFactory, InitializingBean { @@ -30,7 +31,8 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean { public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) { return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl()) .apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName()) - .maxRetries(embeddingModelConfig.getMaxRetries()) + .maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60)) + .connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60)).readTimeout(ofSeconds(60)) .logRequests(embeddingModelConfig.getLogRequests()) .logResponses(embeddingModelConfig.getLogResponses()).build(); } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java index 084978319..7bd5ee25e 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java @@ -324,6 +324,38 @@ class SqlReplaceHelperTest { replaceSql); } + @Test + void testReplaceAliasFieldName() { + Map map = new HashMap<>(); + map.put("总访问次数", "\"总访问次数\""); + map.put("访问次数", "\"访问次数\""); + String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; + String replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); + System.out.println(replaceSql); + Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " + + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", + replaceSql); + + sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 总访问次数 desc limit 10"; + replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); + System.out.println(replaceSql); + Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " + + "GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", replaceSql); + + sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 访问次数 desc limit 10"; + replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); + System.out.println(replaceSql); + Assert.assertEquals("SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, " + + "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10", + replaceSql); + } + @Test void testReplaceAggAliasOrderbyField() { String sql = "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 " diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/HanadbAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/HanadbAdaptor.java index 55aaf4f39..f16ca6871 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/HanadbAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/HanadbAdaptor.java @@ -7,7 +7,7 @@ public class HanadbAdaptor extends DefaultDbAdaptor { @Override public String rewriteSql(String sql) { - return sql.replaceAll("`", "\""); + return sql.replaceAll("`(.*?)`", "\"$1\"").replaceAll("\"([A-Z0-9_]+?)\"", "$1"); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java index cee62013d..9ce495cff 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.core.translator.converter; +import com.tencent.supersonic.common.jsqlparser.SqlAsHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; @@ -78,6 +79,9 @@ public class SqlQueryConverter implements QueryConverter { generateDerivedMetric(sqlGenerateUtils, queryStatement); queryStatement.setSql(sqlQueryParam.getSql()); + // replace sql fields for db, must called after convertNameToBizName + String sqlRewrite = replaceSqlFieldsForHanaDB(queryStatement, sqlQueryParam.getSql()); + sqlQueryParam.setSql(sqlRewrite); log.info("parse sqlQuery [{}] ", sqlQueryParam); } @@ -224,6 +228,54 @@ public class SqlQueryConverter implements QueryConverter { } + /** + * special process for hanaDB,the sap hana DB don't support the chinese name as + * the column name, + * so we need to quote the column name after converting the convertNameToBizName + * called + * + * sap hana DB will auto translate the colume to upper case letter if not + * quoted. + * also we need to quote the field name if it is a lower case letter. + * + * @param queryStatement + * @param sql + * @return + */ + private String replaceSqlFieldsForHanaDB(QueryStatement queryStatement, String sql) { + SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp(); + if (!semanticSchemaResp.getDatabaseResp().getType().equalsIgnoreCase(EngineType.HANADB.getName())) { + return sql; + } + Map fieldNameToBizNameMap = getFieldNameToBizNameMap(semanticSchemaResp); + + Map fieldNameToBizNameMapQuote = new HashMap<>(); + fieldNameToBizNameMap.forEach((key, value) -> { + if (!fieldNameToBizNameMapQuote.containsKey(value) && !value.matches("\".*\"") + && !value.matches("[A-Z0-9_].*?")) { + fieldNameToBizNameMapQuote.put(value, "\"" + value + "\""); + } + }); + String sqlNew = sql; + if (fieldNameToBizNameMapQuote.size() > 0) { + sqlNew = SqlReplaceHelper.replaceFields(sql, fieldNameToBizNameMapQuote, true); + } + // replace alias field name + List asFields = SqlAsHelper.getAsFields(sqlNew); + Map fieldMapput = new HashMap<>(); + for (String asField : asFields) { + String value = asField; + if (!value.matches("\".*?\"") && !value.matches("[A-Z0-9_].*?")) { + value = "\"" + asField + "\""; + fieldMapput.put(asField, value); + } + } + if (fieldMapput.size() > 0) { + sqlNew = SqlReplaceHelper.replaceAliasFieldName(sqlNew, fieldMapput); + } + return sqlNew; + } + private void convertNameToBizName(QueryStatement queryStatement) { SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp(); Map fieldNameToBizNameMap = getFieldNameToBizNameMap(semanticSchemaResp); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/SemanticNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/SemanticNode.java index 88f2a7d80..847aacb68 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/SemanticNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/calcite/node/SemanticNode.java @@ -80,9 +80,13 @@ public abstract class SemanticNode { scope.getValidator().getCatalogReader().getRootSchema(), engineType); if (Configuration.getSqlAdvisor(sqlValidatorWithHints, engineType).getReservedAndKeyWords() .contains(expression.toUpperCase())) { - expression = String.format("`%s`", expression); + if (engineType == EngineType.HANADB) { + expression = String.format("\"%s\"", expression); + } else { + expression = String.format("`%s`", expression); + } } - SqlParser sqlParser = + SqlParser sqlParser = SqlParser.create(expression, Configuration.getParserConfig(engineType)); SqlNode sqlNode = sqlParser.parseExpression(); scope.validateExpr(sqlNode);