diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java index 958272840..28b7c1092 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java @@ -27,6 +27,8 @@ public class ModelDetail { private List fields = Lists.newArrayList(); + private List sqlVariables = Lists.newArrayList(); + public String getSqlQuery() { if (StringUtils.isNotBlank(sqlQuery) && sqlQuery.endsWith(";")) { sqlQuery = sqlQuery.substring(0, sqlQuery.length() - 1); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlVariable.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlVariable.java new file mode 100644 index 000000000..bdee9e5dd --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlVariable.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.headless.api.pojo; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType; +import lombok.Data; + +import java.util.List; + +@Data +public class SqlVariable { + private String name; + private VariableValueType valueType; + private List defaultValues = Lists.newArrayList(); +} + + diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/VariableValueType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/VariableValueType.java new file mode 100644 index 000000000..59fdd33fc --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/VariableValueType.java @@ -0,0 +1,7 @@ +package com.tencent.supersonic.headless.api.pojo.enums; + +public enum VariableValueType { + STRING, + NUMBER, + EXPR +} diff --git a/headless/core/pom.xml b/headless/core/pom.xml index be65413f8..1b4cf3edf 100644 --- a/headless/core/pom.xml +++ b/headless/core/pom.xml @@ -24,6 +24,11 @@ ${lombok.version} provided + + org.antlr + ST4 + ${st.version} + org.springframework diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlVariableParseUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlVariableParseUtils.java new file mode 100644 index 000000000..c1db3ac93 --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlVariableParseUtils.java @@ -0,0 +1,120 @@ +package com.tencent.supersonic.headless.core.utils; + +import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; +import com.tencent.supersonic.headless.api.pojo.Param; +import com.tencent.supersonic.headless.api.pojo.SqlVariable; +import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; +import org.stringtemplate.v4.ST; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import static com.tencent.supersonic.common.pojo.Constants.COMMA; +import static com.tencent.supersonic.common.pojo.Constants.EMPTY; + +@Slf4j +public class SqlVariableParseUtils { + + public static final String REG_SENSITIVE_SQL = "drop\\s|alter\\s|grant\\s|insert\\s|replace\\s|delete\\s|" + + "truncate\\s|update\\s|remove\\s"; + public static final Pattern PATTERN_SENSITIVE_SQL = Pattern.compile(REG_SENSITIVE_SQL); + + public static final String APOSTROPHE = "'"; + + private static final char delimiter = '$'; + + public static String parse(String sql, List sqlVariables, List params) { + if (CollectionUtils.isEmpty(sqlVariables)) { + return sql; + } + Map queryParams = new HashMap<>(); + //1. handle default variable value + sqlVariables.forEach(variable -> { + queryParams.put(variable.getName().trim(), + getValues(variable.getValueType(), variable.getDefaultValues())); + }); + + //override by variable param + if (!CollectionUtils.isEmpty(params)) { + Map> map = + sqlVariables.stream().collect(Collectors.groupingBy(SqlVariable::getName)); + params.forEach(p -> { + if (map.containsKey(p.getName())) { + List list = map.get(p.getName()); + if (!CollectionUtils.isEmpty(list)) { + SqlVariable v = list.get(list.size() - 1); + queryParams.put(p.getName().trim(), getValue(v.getValueType(), p.getValue())); + } + } + }); + } + + queryParams.forEach((k, v) -> { + if (v instanceof List && ((List) v).size() > 0) { + v = ((List) v).stream().collect(Collectors.joining(COMMA)).toString(); + } + queryParams.put(k, v); + }); + ST st = new ST(sql, delimiter, delimiter); + if (!CollectionUtils.isEmpty(queryParams)) { + queryParams.forEach(st::add); + } + return st.render(); + } + + public static List getValues(VariableValueType valueType, List values) { + if (CollectionUtils.isEmpty(values)) { + return new ArrayList<>(); + } + if (null != valueType) { + switch (valueType) { + case STRING: + return values.stream().map(String::valueOf) + .map(s -> s.startsWith(APOSTROPHE) && s.endsWith(APOSTROPHE) + ? s : String.join(EMPTY, APOSTROPHE, s, APOSTROPHE)) + .collect(Collectors.toList()); + case EXPR: + values.stream().map(String::valueOf).forEach(SqlVariableParseUtils::checkSensitiveSql); + return values.stream().map(String::valueOf).collect(Collectors.toList()); + case NUMBER: + return values.stream().map(String::valueOf).collect(Collectors.toList()); + default: + return values.stream().map(String::valueOf).collect(Collectors.toList()); + } + } + return values.stream().map(String::valueOf).collect(Collectors.toList()); + } + + public static Object getValue(VariableValueType valueType, String value) { + if (!StringUtils.isEmpty(value)) { + if (null != valueType) { + switch (valueType) { + case STRING: + return String.join(EMPTY, value.startsWith(APOSTROPHE) ? EMPTY : APOSTROPHE, + value, value.endsWith(APOSTROPHE) ? EMPTY : APOSTROPHE); + case NUMBER: + case EXPR: + default: + return value; + } + } + } + return value; + } + + public static void checkSensitiveSql(String sql) { + Matcher matcher = PATTERN_SENSITIVE_SQL.matcher(sql.toLowerCase()); + if (matcher.find()) { + String group = matcher.group(); + log.warn("Sensitive SQL operations are not allowed: {}", group.toUpperCase()); + throw new InvalidArgumentException("Sensitive SQL operations are not allowed: " + group.toUpperCase()); + } + } + +} diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java new file mode 100644 index 000000000..c9e119082 --- /dev/null +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java @@ -0,0 +1,71 @@ +package com.tencent.supersonic.headless.server.utils; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.headless.api.pojo.Param; +import com.tencent.supersonic.headless.api.pojo.SqlVariable; +import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType; +import com.tencent.supersonic.headless.core.utils.SqlVariableParseUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import java.util.List; + +public class SqlVariableParseUtilsTest { + + @Test + void testParseSql_defaultVariableValue() { + String sql = "select * from t_$interval$ where id = $id$ and name = $name$"; + List variables = Lists.newArrayList(mockNumSqlVariable(), + mockExprSqlVariable(), mockStrSqlVariable()); + String actualSql = SqlVariableParseUtils.parse(sql, variables, Lists.newArrayList()); + String expectedSql = "select * from t_d where id = 1 and name = 'tom'"; + Assertions.assertEquals(expectedSql, actualSql); + } + + @Test + void testParseSql() { + String sql = "select * from t_$interval$ where id = $id$ and name = $name$"; + List variables = Lists.newArrayList(mockNumSqlVariable(), + mockExprSqlVariable(), mockStrSqlVariable()); + List params = Lists.newArrayList(mockIdParam(), mockNameParam(), mockIntervalParam()); + String actualSql = SqlVariableParseUtils.parse(sql, variables, params); + String expectedSql = "select * from t_wk where id = 2 and name = 'alice'"; + Assertions.assertEquals(expectedSql, actualSql); + } + + private SqlVariable mockNumSqlVariable() { + return mockSqlVariable("id", VariableValueType.NUMBER, 1); + } + + private SqlVariable mockStrSqlVariable() { + return mockSqlVariable("name", VariableValueType.STRING, "tom"); + } + + private SqlVariable mockExprSqlVariable() { + return mockSqlVariable("interval", VariableValueType.EXPR, "d"); + } + + private SqlVariable mockSqlVariable(String name, VariableValueType variableValueType, Object value) { + SqlVariable sqlVariable = new SqlVariable(); + sqlVariable.setName(name); + sqlVariable.setValueType(variableValueType); + sqlVariable.setDefaultValues(Lists.newArrayList(value)); + return sqlVariable; + } + + private Param mockIdParam() { + return mockParam("id", "2"); + } + + private Param mockNameParam() { + return mockParam("name", "alice"); + } + + private Param mockIntervalParam() { + return mockParam("interval", "wk"); + } + + private Param mockParam(String name, String value) { + return new Param(name, value); + } + +} diff --git a/pom.xml b/pom.xml index 1c012b092..66e108049 100644 --- a/pom.xml +++ b/pom.xml @@ -74,6 +74,7 @@ 3.17 0.24.0 42.7.1 + 4.0.8