(improvement)(Headless) support validation of special characters when creating metrics, dimensions, etc. (#914)

This commit is contained in:
lexluo09
2024-04-16 22:30:10 +08:00
committed by GitHub
parent fd7de6255a
commit 5672aade1d
6 changed files with 80 additions and 28 deletions

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.server.rest.api;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
@@ -10,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
@@ -32,6 +34,9 @@ public class SqlQueryApiController {
@Autowired
private QueryService queryService;
@Autowired
private SemanticService semanticService;
@PostMapping("/sql")
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
HttpServletRequest request,
@@ -59,6 +64,8 @@ public class SqlQueryApiController {
private void correct(QuerySqlReq querySqlReq) {
QueryContext queryCtx = new QueryContext();
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());

View File

@@ -45,6 +45,7 @@ import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.utils.DimensionConverter;
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
@@ -83,12 +84,12 @@ public class DimensionServiceImpl implements DimensionService {
public DimensionServiceImpl(DimensionRepository dimensionRepository,
ModelService modelService,
ChatGptHelper chatGptHelper,
DatabaseService databaseService,
ModelRelaService modelRelaService,
DataSetService dataSetService,
TagMetaService tagMetaService) {
ModelService modelService,
ChatGptHelper chatGptHelper,
DatabaseService databaseService,
ModelRelaService modelRelaService,
DataSetService dataSetService,
TagMetaService tagMetaService) {
this.modelService = modelService;
this.dimensionRepository = dimensionRepository;
this.chatGptHelper = chatGptHelper;
@@ -398,8 +399,10 @@ public class DimensionServiceImpl implements DimensionService {
Map<String, DimensionResp> nameMap = dimensionResps.stream()
.collect(Collectors.toMap(DimensionResp::getName, a -> a, (k1, k2) -> k1));
for (DimensionReq dimensionReq : dimensionReqs) {
if (NameCheckUtils.containsSpecialCharacters(dimensionReq.getName())) {
throw new InvalidArgumentException("名称包含特殊字符, 请修改: " + dimensionReq.getName());
String forbiddenCharacters = NameCheckUtils.findForbiddenCharacters(dimensionReq.getName());
if (StringUtils.isNotBlank(forbiddenCharacters)) {
throw new InvalidArgumentException(String.format("名称包含特殊字符, 请修改: %s特殊字符: %s",
dimensionReq.getBizName(), forbiddenCharacters));
}
if (bizNameMap.containsKey(dimensionReq.getBizName())) {
DimensionResp dimensionResp = bizNameMap.get(dimensionReq.getBizName());

View File

@@ -81,13 +81,13 @@ public class ModelServiceImpl implements ModelService {
private DateInfoRepository dateInfoRepository;
public ModelServiceImpl(ModelRepository modelRepository,
DatabaseService databaseService,
@Lazy DimensionService dimensionService,
@Lazy MetricService metricService,
DomainService domainService,
UserService userService,
DataSetService dataSetService,
DateInfoRepository dateInfoRepository) {
DatabaseService databaseService,
@Lazy DimensionService dimensionService,
@Lazy MetricService metricService,
DomainService domainService,
UserService userService,
DataSetService dataSetService,
DateInfoRepository dateInfoRepository) {
this.modelRepository = modelRepository;
this.databaseService = databaseService;
this.dimensionService = dimensionService;
@@ -217,8 +217,9 @@ public class ModelServiceImpl implements ModelService {
}
private void checkName(ModelReq modelReq) {
if (NameCheckUtils.containsSpecialCharacters(modelReq.getName())) {
String message = String.format("模型名称[%s]包含特殊字符, 请修改", modelReq.getName());
String forbiddenCharacters = NameCheckUtils.findForbiddenCharacters(modelReq.getName());
if (StringUtils.isNotBlank(forbiddenCharacters)) {
String message = String.format("模型名称[%s]包含特殊字符(%s), 请修改", modelReq.getName(), forbiddenCharacters);
throw new InvalidArgumentException(message);
}
List<Dim> dims = modelReq.getModelDetail().getDimensions();
@@ -232,23 +233,27 @@ public class ModelServiceImpl implements ModelService {
throw new InvalidArgumentException("有度量时, 不可缺少时间维度");
}
for (Measure measure : measures) {
String measureForbiddenCharacters = NameCheckUtils.findForbiddenCharacters(measure.getName());
if (StringUtils.isNotBlank(measure.getName())
&& NameCheckUtils.containsSpecialCharacters(measure.getName())) {
String message = String.format("度量[%s]包含特殊字符, 请修改", measure.getName());
&& StringUtils.isNotBlank(measureForbiddenCharacters)) {
String message = String.format("度量[%s]包含特殊字符(%s), 请修改", measure.getName(), measureForbiddenCharacters);
throw new InvalidArgumentException(message);
}
}
for (Dim dim : dims) {
String dimForbiddenCharacters = NameCheckUtils.findForbiddenCharacters(dim.getName());
if (StringUtils.isNotBlank(dim.getName())
&& NameCheckUtils.containsSpecialCharacters(dim.getName())) {
String message = String.format("维度[%s]包含特殊字符, 请修改", dim.getName());
&& StringUtils.isNotBlank(dimForbiddenCharacters)) {
String message = String.format("维度[%s]包含特殊字符(%s), 请修改", dim.getName(), dimForbiddenCharacters);
throw new InvalidArgumentException(message);
}
}
for (Identify identify : identifies) {
String identifyForbiddenCharacters = NameCheckUtils.findForbiddenCharacters(identify.getName());
if (StringUtils.isNotBlank(identify.getName())
&& NameCheckUtils.containsSpecialCharacters(identify.getName())) {
String message = String.format("主键/外键[%s]包含特殊字符, 请修改", identify.getName());
&& StringUtils.isNotBlank(identifyForbiddenCharacters)) {
String message = String.format("主键/外键[%s]包含特殊字符(%s), 请修改", identify.getName(),
identifyForbiddenCharacters);
throw new InvalidArgumentException(message);
}
}

View File

@@ -56,8 +56,9 @@ public class MetricCheckUtils {
if (StringUtils.isBlank(expr)) {
throw new InvalidArgumentException("表达式不可为空");
}
if (NameCheckUtils.containsSpecialCharacters(metricReq.getName())) {
throw new InvalidArgumentException("名称包含特殊字符, 请修改");
String forbiddenCharacters = NameCheckUtils.findForbiddenCharacters(metricReq.getName());
if (StringUtils.isNotBlank(forbiddenCharacters)) {
throw new InvalidArgumentException(String.format("名称包含特殊字符%s, 请修改", forbiddenCharacters));
}
}

View File

@@ -1,9 +1,25 @@
package com.tencent.supersonic.headless.server.utils;
public class NameCheckUtils {
import org.apache.commons.lang3.StringUtils;
public static boolean containsSpecialCharacters(String str) {
return false;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class NameCheckUtils {
public static final String forbiddenCharactersRegex = "[%#()]";
public static String findForbiddenCharacters(String str) {
if (StringUtils.isBlank(str)) {
return "";
}
Pattern pattern = Pattern.compile(forbiddenCharactersRegex);
Matcher matcher = pattern.matcher(str);
StringBuilder foundCharacters = new StringBuilder();
while (matcher.find()) {
foundCharacters.append(matcher.group()).append(" ");
}
return foundCharacters.toString().trim();
}
}

View File

@@ -0,0 +1,20 @@
package com.tencent.supersonic.headless.server.utils;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test;
import org.testng.Assert;
class NameCheckUtilsTest {
@Test
void findForbiddenCharacters() {
Assert.assertTrue(StringUtils.isBlank(NameCheckUtils.findForbiddenCharacters("访问时长")));
Assert.assertTrue(StringUtils.isNotBlank(NameCheckUtils.findForbiddenCharacters("访问时长(秒)")));
Assert.assertTrue(StringUtils.isNotBlank(NameCheckUtils.findForbiddenCharacters("访问时长#")));
Assert.assertTrue(StringUtils.isNotBlank(NameCheckUtils.findForbiddenCharacters("访问时长%")));
Assert.assertTrue(StringUtils.isNotBlank(NameCheckUtils.findForbiddenCharacters("访问时长(")));
Assert.assertTrue(StringUtils.isNotBlank(NameCheckUtils.findForbiddenCharacters("访问时长)")));
Assert.assertTrue(StringUtils.isNotBlank(NameCheckUtils.findForbiddenCharacters("访问时长(")));
Assert.assertTrue(StringUtils.isNotBlank(NameCheckUtils.findForbiddenCharacters("访问时长)")));
}
}