From 28f95ddf9e14322557693d1cca997e81ceb59433 Mon Sep 17 00:00:00 2001
From: lexluo09 <39718951+lexluo09@users.noreply.github.com>
Date: Thu, 27 Jun 2024 23:02:59 +0800
Subject: [PATCH] (improvement)(test) Fix NullPointerException in
SelectCorrectorTest unit test (#1260)
---
headless/chat/pom.xml | 6 ++++
.../chat/corrector/SelectCorrector.java | 30 +++++++------------
.../chat/corrector/SelectCorrectorTest.java | 16 ++++++++--
3 files changed, 31 insertions(+), 21 deletions(-)
diff --git a/headless/chat/pom.xml b/headless/chat/pom.xml
index 2a25fdd6b..bb6acbd32 100644
--- a/headless/chat/pom.xml
+++ b/headless/chat/pom.xml
@@ -116,6 +116,12 @@
mockito-core
test
+
+ org.mockito
+ mockito-inline
+ ${mockito-inline.version}
+ test
+
\ No newline at end of file
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java
index 7f33e7aad..87826a8db 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java
@@ -11,17 +11,17 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.Objects;
-import java.util.Set;
-import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
+import java.util.ArrayList;
+import java.util.HashSet;
import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "Select" section in S2SQL.
@@ -29,6 +29,8 @@ import java.util.List;
@Slf4j
public class SelectCorrector extends BaseSemanticCorrector {
+ public static final String ADDITIONAL_INFORMATION = "s2.corrector.additional.information";
+
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
@@ -46,14 +48,15 @@ public class SelectCorrector extends BaseSemanticCorrector {
}
protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
- String correctS2SQL) {
+ String correctS2SQL) {
correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL);
Set selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
//decide whether add order by expression field to select
- String correctorAdditionalInfo = getAdditionalInfo();
+ Environment environment = ContextUtils.getBean(Environment.class);
+ String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION);
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
@@ -67,7 +70,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
}
private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
- String correctS2SQL) {
+ String correctS2SQL) {
//If it is in DETAIL mode and select *, add default metrics and dimensions.
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
@@ -96,15 +99,4 @@ public class SelectCorrector extends BaseSemanticCorrector {
}
return correctS2SQL;
}
-
- private String getAdditionalInfo() {
- String correctorAdditionalInfo = null;
- try {
- Environment environment = ContextUtils.getBean(Environment.class);
- correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
- } catch (Exception e) {
- log.error("getAdditionalInfo error:{}", e);
- }
- return correctorAdditionalInfo;
- }
}
diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java
index ec99cd2cd..51474df41 100644
--- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java
+++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java
@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.QueryType;
+import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
@@ -10,12 +11,18 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.chat.QueryContext;
+import org.junit.Assert;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.springframework.core.env.Environment;
+
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
-import org.junit.Assert;
-import org.junit.jupiter.api.Test;
+
+import static org.mockito.Mockito.when;
class SelectCorrectorTest {
@@ -23,6 +30,11 @@ class SelectCorrectorTest {
@Test
void testDoCorrect() {
+ MockedStatic mocked = Mockito.mockStatic(ContextUtils.class);
+ Environment mockEnvironment = Mockito.mock(Environment.class);
+ mocked.when(() -> ContextUtils.getBean(Environment.class)).thenReturn(mockEnvironment);
+ when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn("");
+
BaseSemanticCorrector corrector = new SelectCorrector();
QueryContext queryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();