(improvement)(headless) add sql variable for model sql (#740)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-02-23 11:00:06 +08:00
committed by GitHub
parent e95a528219
commit d8043c356f
7 changed files with 222 additions and 0 deletions

View File

@@ -27,6 +27,8 @@ public class ModelDetail {
private List<Field> fields = Lists.newArrayList();
private List<SqlVariable> sqlVariables = Lists.newArrayList();
public String getSqlQuery() {
if (StringUtils.isNotBlank(sqlQuery) && sqlQuery.endsWith(";")) {
sqlQuery = sqlQuery.substring(0, sqlQuery.length() - 1);

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.api.pojo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
import lombok.Data;
import java.util.List;
@Data
public class SqlVariable {
private String name;
private VariableValueType valueType;
private List<Object> defaultValues = Lists.newArrayList();
}

View File

@@ -0,0 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.enums;
public enum VariableValueType {
STRING,
NUMBER,
EXPR
}

View File

@@ -24,6 +24,11 @@
<version>${lombok.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>ST4</artifactId>
<version>${st.version}</version>
</dependency>
<dependency>
<groupId>org.springframework</groupId>

View File

@@ -0,0 +1,120 @@
package com.tencent.supersonic.headless.core.utils;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.Param;
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import org.stringtemplate.v4.ST;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
import static com.tencent.supersonic.common.pojo.Constants.EMPTY;
@Slf4j
public class SqlVariableParseUtils {
public static final String REG_SENSITIVE_SQL = "drop\\s|alter\\s|grant\\s|insert\\s|replace\\s|delete\\s|"
+ "truncate\\s|update\\s|remove\\s";
public static final Pattern PATTERN_SENSITIVE_SQL = Pattern.compile(REG_SENSITIVE_SQL);
public static final String APOSTROPHE = "'";
private static final char delimiter = '$';
public static String parse(String sql, List<SqlVariable> sqlVariables, List<Param> params) {
if (CollectionUtils.isEmpty(sqlVariables)) {
return sql;
}
Map<String, Object> queryParams = new HashMap<>();
//1. handle default variable value
sqlVariables.forEach(variable -> {
queryParams.put(variable.getName().trim(),
getValues(variable.getValueType(), variable.getDefaultValues()));
});
//override by variable param
if (!CollectionUtils.isEmpty(params)) {
Map<String, List<SqlVariable>> map =
sqlVariables.stream().collect(Collectors.groupingBy(SqlVariable::getName));
params.forEach(p -> {
if (map.containsKey(p.getName())) {
List<SqlVariable> list = map.get(p.getName());
if (!CollectionUtils.isEmpty(list)) {
SqlVariable v = list.get(list.size() - 1);
queryParams.put(p.getName().trim(), getValue(v.getValueType(), p.getValue()));
}
}
});
}
queryParams.forEach((k, v) -> {
if (v instanceof List && ((List) v).size() > 0) {
v = ((List) v).stream().collect(Collectors.joining(COMMA)).toString();
}
queryParams.put(k, v);
});
ST st = new ST(sql, delimiter, delimiter);
if (!CollectionUtils.isEmpty(queryParams)) {
queryParams.forEach(st::add);
}
return st.render();
}
public static List<String> getValues(VariableValueType valueType, List<Object> values) {
if (CollectionUtils.isEmpty(values)) {
return new ArrayList<>();
}
if (null != valueType) {
switch (valueType) {
case STRING:
return values.stream().map(String::valueOf)
.map(s -> s.startsWith(APOSTROPHE) && s.endsWith(APOSTROPHE)
? s : String.join(EMPTY, APOSTROPHE, s, APOSTROPHE))
.collect(Collectors.toList());
case EXPR:
values.stream().map(String::valueOf).forEach(SqlVariableParseUtils::checkSensitiveSql);
return values.stream().map(String::valueOf).collect(Collectors.toList());
case NUMBER:
return values.stream().map(String::valueOf).collect(Collectors.toList());
default:
return values.stream().map(String::valueOf).collect(Collectors.toList());
}
}
return values.stream().map(String::valueOf).collect(Collectors.toList());
}
public static Object getValue(VariableValueType valueType, String value) {
if (!StringUtils.isEmpty(value)) {
if (null != valueType) {
switch (valueType) {
case STRING:
return String.join(EMPTY, value.startsWith(APOSTROPHE) ? EMPTY : APOSTROPHE,
value, value.endsWith(APOSTROPHE) ? EMPTY : APOSTROPHE);
case NUMBER:
case EXPR:
default:
return value;
}
}
}
return value;
}
public static void checkSensitiveSql(String sql) {
Matcher matcher = PATTERN_SENSITIVE_SQL.matcher(sql.toLowerCase());
if (matcher.find()) {
String group = matcher.group();
log.warn("Sensitive SQL operations are not allowed: {}", group.toUpperCase());
throw new InvalidArgumentException("Sensitive SQL operations are not allowed: " + group.toUpperCase());
}
}
}

View File

@@ -0,0 +1,71 @@
package com.tencent.supersonic.headless.server.utils;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.Param;
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
import com.tencent.supersonic.headless.core.utils.SqlVariableParseUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
public class SqlVariableParseUtilsTest {
@Test
void testParseSql_defaultVariableValue() {
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
mockExprSqlVariable(), mockStrSqlVariable());
String actualSql = SqlVariableParseUtils.parse(sql, variables, Lists.newArrayList());
String expectedSql = "select * from t_d where id = 1 and name = 'tom'";
Assertions.assertEquals(expectedSql, actualSql);
}
@Test
void testParseSql() {
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
mockExprSqlVariable(), mockStrSqlVariable());
List<Param> params = Lists.newArrayList(mockIdParam(), mockNameParam(), mockIntervalParam());
String actualSql = SqlVariableParseUtils.parse(sql, variables, params);
String expectedSql = "select * from t_wk where id = 2 and name = 'alice'";
Assertions.assertEquals(expectedSql, actualSql);
}
private SqlVariable mockNumSqlVariable() {
return mockSqlVariable("id", VariableValueType.NUMBER, 1);
}
private SqlVariable mockStrSqlVariable() {
return mockSqlVariable("name", VariableValueType.STRING, "tom");
}
private SqlVariable mockExprSqlVariable() {
return mockSqlVariable("interval", VariableValueType.EXPR, "d");
}
private SqlVariable mockSqlVariable(String name, VariableValueType variableValueType, Object value) {
SqlVariable sqlVariable = new SqlVariable();
sqlVariable.setName(name);
sqlVariable.setValueType(variableValueType);
sqlVariable.setDefaultValues(Lists.newArrayList(value));
return sqlVariable;
}
private Param mockIdParam() {
return mockParam("id", "2");
}
private Param mockNameParam() {
return mockParam("name", "alice");
}
private Param mockIntervalParam() {
return mockParam("interval", "wk");
}
private Param mockParam(String name, String value) {
return new Param(name, value);
}
}

View File

@@ -74,6 +74,7 @@
<poi.version>3.17</poi.version>
<langchain4j.version>0.24.0</langchain4j.version>
<postgresql.version>42.7.1</postgresql.version>
<st.version>4.0.8</st.version>
</properties>
<dependencyManagement>