mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(headless) add sql variable for model sql (#740)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -27,6 +27,8 @@ public class ModelDetail {
|
||||
|
||||
private List<Field> fields = Lists.newArrayList();
|
||||
|
||||
private List<SqlVariable> sqlVariables = Lists.newArrayList();
|
||||
|
||||
public String getSqlQuery() {
|
||||
if (StringUtils.isNotBlank(sqlQuery) && sqlQuery.endsWith(";")) {
|
||||
sqlQuery = sqlQuery.substring(0, sqlQuery.length() - 1);
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class SqlVariable {
|
||||
private String name;
|
||||
private VariableValueType valueType;
|
||||
private List<Object> defaultValues = Lists.newArrayList();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||
|
||||
public enum VariableValueType {
|
||||
STRING,
|
||||
NUMBER,
|
||||
EXPR
|
||||
}
|
||||
@@ -24,6 +24,11 @@
|
||||
<version>${lombok.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.antlr</groupId>
|
||||
<artifactId>ST4</artifactId>
|
||||
<version>${st.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
package com.tencent.supersonic.headless.core.utils;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.headless.api.pojo.Param;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.stringtemplate.v4.ST;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.EMPTY;
|
||||
|
||||
@Slf4j
|
||||
public class SqlVariableParseUtils {
|
||||
|
||||
public static final String REG_SENSITIVE_SQL = "drop\\s|alter\\s|grant\\s|insert\\s|replace\\s|delete\\s|"
|
||||
+ "truncate\\s|update\\s|remove\\s";
|
||||
public static final Pattern PATTERN_SENSITIVE_SQL = Pattern.compile(REG_SENSITIVE_SQL);
|
||||
|
||||
public static final String APOSTROPHE = "'";
|
||||
|
||||
private static final char delimiter = '$';
|
||||
|
||||
public static String parse(String sql, List<SqlVariable> sqlVariables, List<Param> params) {
|
||||
if (CollectionUtils.isEmpty(sqlVariables)) {
|
||||
return sql;
|
||||
}
|
||||
Map<String, Object> queryParams = new HashMap<>();
|
||||
//1. handle default variable value
|
||||
sqlVariables.forEach(variable -> {
|
||||
queryParams.put(variable.getName().trim(),
|
||||
getValues(variable.getValueType(), variable.getDefaultValues()));
|
||||
});
|
||||
|
||||
//override by variable param
|
||||
if (!CollectionUtils.isEmpty(params)) {
|
||||
Map<String, List<SqlVariable>> map =
|
||||
sqlVariables.stream().collect(Collectors.groupingBy(SqlVariable::getName));
|
||||
params.forEach(p -> {
|
||||
if (map.containsKey(p.getName())) {
|
||||
List<SqlVariable> list = map.get(p.getName());
|
||||
if (!CollectionUtils.isEmpty(list)) {
|
||||
SqlVariable v = list.get(list.size() - 1);
|
||||
queryParams.put(p.getName().trim(), getValue(v.getValueType(), p.getValue()));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
queryParams.forEach((k, v) -> {
|
||||
if (v instanceof List && ((List) v).size() > 0) {
|
||||
v = ((List) v).stream().collect(Collectors.joining(COMMA)).toString();
|
||||
}
|
||||
queryParams.put(k, v);
|
||||
});
|
||||
ST st = new ST(sql, delimiter, delimiter);
|
||||
if (!CollectionUtils.isEmpty(queryParams)) {
|
||||
queryParams.forEach(st::add);
|
||||
}
|
||||
return st.render();
|
||||
}
|
||||
|
||||
public static List<String> getValues(VariableValueType valueType, List<Object> values) {
|
||||
if (CollectionUtils.isEmpty(values)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
if (null != valueType) {
|
||||
switch (valueType) {
|
||||
case STRING:
|
||||
return values.stream().map(String::valueOf)
|
||||
.map(s -> s.startsWith(APOSTROPHE) && s.endsWith(APOSTROPHE)
|
||||
? s : String.join(EMPTY, APOSTROPHE, s, APOSTROPHE))
|
||||
.collect(Collectors.toList());
|
||||
case EXPR:
|
||||
values.stream().map(String::valueOf).forEach(SqlVariableParseUtils::checkSensitiveSql);
|
||||
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||
case NUMBER:
|
||||
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||
default:
|
||||
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static Object getValue(VariableValueType valueType, String value) {
|
||||
if (!StringUtils.isEmpty(value)) {
|
||||
if (null != valueType) {
|
||||
switch (valueType) {
|
||||
case STRING:
|
||||
return String.join(EMPTY, value.startsWith(APOSTROPHE) ? EMPTY : APOSTROPHE,
|
||||
value, value.endsWith(APOSTROPHE) ? EMPTY : APOSTROPHE);
|
||||
case NUMBER:
|
||||
case EXPR:
|
||||
default:
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
public static void checkSensitiveSql(String sql) {
|
||||
Matcher matcher = PATTERN_SENSITIVE_SQL.matcher(sql.toLowerCase());
|
||||
if (matcher.find()) {
|
||||
String group = matcher.group();
|
||||
log.warn("Sensitive SQL operations are not allowed: {}", group.toUpperCase());
|
||||
throw new InvalidArgumentException("Sensitive SQL operations are not allowed: " + group.toUpperCase());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package com.tencent.supersonic.headless.server.utils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.Param;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
|
||||
import com.tencent.supersonic.headless.core.utils.SqlVariableParseUtils;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import java.util.List;
|
||||
|
||||
public class SqlVariableParseUtilsTest {
|
||||
|
||||
@Test
|
||||
void testParseSql_defaultVariableValue() {
|
||||
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
|
||||
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
|
||||
mockExprSqlVariable(), mockStrSqlVariable());
|
||||
String actualSql = SqlVariableParseUtils.parse(sql, variables, Lists.newArrayList());
|
||||
String expectedSql = "select * from t_d where id = 1 and name = 'tom'";
|
||||
Assertions.assertEquals(expectedSql, actualSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testParseSql() {
|
||||
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
|
||||
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
|
||||
mockExprSqlVariable(), mockStrSqlVariable());
|
||||
List<Param> params = Lists.newArrayList(mockIdParam(), mockNameParam(), mockIntervalParam());
|
||||
String actualSql = SqlVariableParseUtils.parse(sql, variables, params);
|
||||
String expectedSql = "select * from t_wk where id = 2 and name = 'alice'";
|
||||
Assertions.assertEquals(expectedSql, actualSql);
|
||||
}
|
||||
|
||||
private SqlVariable mockNumSqlVariable() {
|
||||
return mockSqlVariable("id", VariableValueType.NUMBER, 1);
|
||||
}
|
||||
|
||||
private SqlVariable mockStrSqlVariable() {
|
||||
return mockSqlVariable("name", VariableValueType.STRING, "tom");
|
||||
}
|
||||
|
||||
private SqlVariable mockExprSqlVariable() {
|
||||
return mockSqlVariable("interval", VariableValueType.EXPR, "d");
|
||||
}
|
||||
|
||||
private SqlVariable mockSqlVariable(String name, VariableValueType variableValueType, Object value) {
|
||||
SqlVariable sqlVariable = new SqlVariable();
|
||||
sqlVariable.setName(name);
|
||||
sqlVariable.setValueType(variableValueType);
|
||||
sqlVariable.setDefaultValues(Lists.newArrayList(value));
|
||||
return sqlVariable;
|
||||
}
|
||||
|
||||
private Param mockIdParam() {
|
||||
return mockParam("id", "2");
|
||||
}
|
||||
|
||||
private Param mockNameParam() {
|
||||
return mockParam("name", "alice");
|
||||
}
|
||||
|
||||
private Param mockIntervalParam() {
|
||||
return mockParam("interval", "wk");
|
||||
}
|
||||
|
||||
private Param mockParam(String name, String value) {
|
||||
return new Param(name, value);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user