mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(headless) add sql variable for model sql (#740)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -27,6 +27,8 @@ public class ModelDetail {
|
|||||||
|
|
||||||
private List<Field> fields = Lists.newArrayList();
|
private List<Field> fields = Lists.newArrayList();
|
||||||
|
|
||||||
|
private List<SqlVariable> sqlVariables = Lists.newArrayList();
|
||||||
|
|
||||||
public String getSqlQuery() {
|
public String getSqlQuery() {
|
||||||
if (StringUtils.isNotBlank(sqlQuery) && sqlQuery.endsWith(";")) {
|
if (StringUtils.isNotBlank(sqlQuery) && sqlQuery.endsWith(";")) {
|
||||||
sqlQuery = sqlQuery.substring(0, sqlQuery.length() - 1);
|
sqlQuery = sqlQuery.substring(0, sqlQuery.length() - 1);
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.tencent.supersonic.headless.api.pojo;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SqlVariable {
|
||||||
|
private String name;
|
||||||
|
private VariableValueType valueType;
|
||||||
|
private List<Object> defaultValues = Lists.newArrayList();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||||
|
|
||||||
|
public enum VariableValueType {
|
||||||
|
STRING,
|
||||||
|
NUMBER,
|
||||||
|
EXPR
|
||||||
|
}
|
||||||
@@ -24,6 +24,11 @@
|
|||||||
<version>${lombok.version}</version>
|
<version>${lombok.version}</version>
|
||||||
<scope>provided</scope>
|
<scope>provided</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.antlr</groupId>
|
||||||
|
<artifactId>ST4</artifactId>
|
||||||
|
<version>${st.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework</groupId>
|
<groupId>org.springframework</groupId>
|
||||||
|
|||||||
@@ -0,0 +1,120 @@
|
|||||||
|
package com.tencent.supersonic.headless.core.utils;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.Param;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
import org.stringtemplate.v4.ST;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
|
||||||
|
import static com.tencent.supersonic.common.pojo.Constants.EMPTY;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class SqlVariableParseUtils {
|
||||||
|
|
||||||
|
public static final String REG_SENSITIVE_SQL = "drop\\s|alter\\s|grant\\s|insert\\s|replace\\s|delete\\s|"
|
||||||
|
+ "truncate\\s|update\\s|remove\\s";
|
||||||
|
public static final Pattern PATTERN_SENSITIVE_SQL = Pattern.compile(REG_SENSITIVE_SQL);
|
||||||
|
|
||||||
|
public static final String APOSTROPHE = "'";
|
||||||
|
|
||||||
|
private static final char delimiter = '$';
|
||||||
|
|
||||||
|
public static String parse(String sql, List<SqlVariable> sqlVariables, List<Param> params) {
|
||||||
|
if (CollectionUtils.isEmpty(sqlVariables)) {
|
||||||
|
return sql;
|
||||||
|
}
|
||||||
|
Map<String, Object> queryParams = new HashMap<>();
|
||||||
|
//1. handle default variable value
|
||||||
|
sqlVariables.forEach(variable -> {
|
||||||
|
queryParams.put(variable.getName().trim(),
|
||||||
|
getValues(variable.getValueType(), variable.getDefaultValues()));
|
||||||
|
});
|
||||||
|
|
||||||
|
//override by variable param
|
||||||
|
if (!CollectionUtils.isEmpty(params)) {
|
||||||
|
Map<String, List<SqlVariable>> map =
|
||||||
|
sqlVariables.stream().collect(Collectors.groupingBy(SqlVariable::getName));
|
||||||
|
params.forEach(p -> {
|
||||||
|
if (map.containsKey(p.getName())) {
|
||||||
|
List<SqlVariable> list = map.get(p.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(list)) {
|
||||||
|
SqlVariable v = list.get(list.size() - 1);
|
||||||
|
queryParams.put(p.getName().trim(), getValue(v.getValueType(), p.getValue()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
queryParams.forEach((k, v) -> {
|
||||||
|
if (v instanceof List && ((List) v).size() > 0) {
|
||||||
|
v = ((List) v).stream().collect(Collectors.joining(COMMA)).toString();
|
||||||
|
}
|
||||||
|
queryParams.put(k, v);
|
||||||
|
});
|
||||||
|
ST st = new ST(sql, delimiter, delimiter);
|
||||||
|
if (!CollectionUtils.isEmpty(queryParams)) {
|
||||||
|
queryParams.forEach(st::add);
|
||||||
|
}
|
||||||
|
return st.render();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<String> getValues(VariableValueType valueType, List<Object> values) {
|
||||||
|
if (CollectionUtils.isEmpty(values)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
if (null != valueType) {
|
||||||
|
switch (valueType) {
|
||||||
|
case STRING:
|
||||||
|
return values.stream().map(String::valueOf)
|
||||||
|
.map(s -> s.startsWith(APOSTROPHE) && s.endsWith(APOSTROPHE)
|
||||||
|
? s : String.join(EMPTY, APOSTROPHE, s, APOSTROPHE))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
case EXPR:
|
||||||
|
values.stream().map(String::valueOf).forEach(SqlVariableParseUtils::checkSensitiveSql);
|
||||||
|
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||||
|
case NUMBER:
|
||||||
|
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||||
|
default:
|
||||||
|
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values.stream().map(String::valueOf).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Object getValue(VariableValueType valueType, String value) {
|
||||||
|
if (!StringUtils.isEmpty(value)) {
|
||||||
|
if (null != valueType) {
|
||||||
|
switch (valueType) {
|
||||||
|
case STRING:
|
||||||
|
return String.join(EMPTY, value.startsWith(APOSTROPHE) ? EMPTY : APOSTROPHE,
|
||||||
|
value, value.endsWith(APOSTROPHE) ? EMPTY : APOSTROPHE);
|
||||||
|
case NUMBER:
|
||||||
|
case EXPR:
|
||||||
|
default:
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void checkSensitiveSql(String sql) {
|
||||||
|
Matcher matcher = PATTERN_SENSITIVE_SQL.matcher(sql.toLowerCase());
|
||||||
|
if (matcher.find()) {
|
||||||
|
String group = matcher.group();
|
||||||
|
log.warn("Sensitive SQL operations are not allowed: {}", group.toUpperCase());
|
||||||
|
throw new InvalidArgumentException("Sensitive SQL operations are not allowed: " + group.toUpperCase());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package com.tencent.supersonic.headless.server.utils;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.Param;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
|
||||||
|
import com.tencent.supersonic.headless.core.utils.SqlVariableParseUtils;
|
||||||
|
import org.junit.jupiter.api.Assertions;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class SqlVariableParseUtilsTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testParseSql_defaultVariableValue() {
|
||||||
|
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
|
||||||
|
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
|
||||||
|
mockExprSqlVariable(), mockStrSqlVariable());
|
||||||
|
String actualSql = SqlVariableParseUtils.parse(sql, variables, Lists.newArrayList());
|
||||||
|
String expectedSql = "select * from t_d where id = 1 and name = 'tom'";
|
||||||
|
Assertions.assertEquals(expectedSql, actualSql);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testParseSql() {
|
||||||
|
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
|
||||||
|
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
|
||||||
|
mockExprSqlVariable(), mockStrSqlVariable());
|
||||||
|
List<Param> params = Lists.newArrayList(mockIdParam(), mockNameParam(), mockIntervalParam());
|
||||||
|
String actualSql = SqlVariableParseUtils.parse(sql, variables, params);
|
||||||
|
String expectedSql = "select * from t_wk where id = 2 and name = 'alice'";
|
||||||
|
Assertions.assertEquals(expectedSql, actualSql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SqlVariable mockNumSqlVariable() {
|
||||||
|
return mockSqlVariable("id", VariableValueType.NUMBER, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SqlVariable mockStrSqlVariable() {
|
||||||
|
return mockSqlVariable("name", VariableValueType.STRING, "tom");
|
||||||
|
}
|
||||||
|
|
||||||
|
private SqlVariable mockExprSqlVariable() {
|
||||||
|
return mockSqlVariable("interval", VariableValueType.EXPR, "d");
|
||||||
|
}
|
||||||
|
|
||||||
|
private SqlVariable mockSqlVariable(String name, VariableValueType variableValueType, Object value) {
|
||||||
|
SqlVariable sqlVariable = new SqlVariable();
|
||||||
|
sqlVariable.setName(name);
|
||||||
|
sqlVariable.setValueType(variableValueType);
|
||||||
|
sqlVariable.setDefaultValues(Lists.newArrayList(value));
|
||||||
|
return sqlVariable;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Param mockIdParam() {
|
||||||
|
return mockParam("id", "2");
|
||||||
|
}
|
||||||
|
|
||||||
|
private Param mockNameParam() {
|
||||||
|
return mockParam("name", "alice");
|
||||||
|
}
|
||||||
|
|
||||||
|
private Param mockIntervalParam() {
|
||||||
|
return mockParam("interval", "wk");
|
||||||
|
}
|
||||||
|
|
||||||
|
private Param mockParam(String name, String value) {
|
||||||
|
return new Param(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
1
pom.xml
1
pom.xml
@@ -74,6 +74,7 @@
|
|||||||
<poi.version>3.17</poi.version>
|
<poi.version>3.17</poi.version>
|
||||||
<langchain4j.version>0.24.0</langchain4j.version>
|
<langchain4j.version>0.24.0</langchain4j.version>
|
||||||
<postgresql.version>42.7.1</postgresql.version>
|
<postgresql.version>42.7.1</postgresql.version>
|
||||||
|
<st.version>4.0.8</st.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencyManagement>
|
<dependencyManagement>
|
||||||
|
|||||||
Reference in New Issue
Block a user