From 28f95ddf9e14322557693d1cca997e81ceb59433 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 27 Jun 2024 23:02:59 +0800 Subject: [PATCH] (improvement)(test) Fix NullPointerException in SelectCorrectorTest unit test (#1260) --- headless/chat/pom.xml | 6 ++++ .../chat/corrector/SelectCorrector.java | 30 +++++++------------ .../chat/corrector/SelectCorrectorTest.java | 16 ++++++++-- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/headless/chat/pom.xml b/headless/chat/pom.xml index 2a25fdd6b..bb6acbd32 100644 --- a/headless/chat/pom.xml +++ b/headless/chat/pom.xml @@ -116,6 +116,12 @@ mockito-core test + + org.mockito + mockito-inline + ${mockito-inline.version} + test + \ No newline at end of file diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index 7f33e7aad..87826a8db 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -11,17 +11,17 @@ import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.QueryContext; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.core.env.Environment; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; /** * Perform SQL corrections on the "Select" section in S2SQL. @@ -29,6 +29,8 @@ import java.util.List; @Slf4j public class SelectCorrector extends BaseSemanticCorrector { + public static final String ADDITIONAL_INFORMATION = "s2.corrector.additional.information"; + @Override public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); @@ -46,14 +48,15 @@ public class SelectCorrector extends BaseSemanticCorrector { } protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo, - String correctS2SQL) { + String correctS2SQL) { correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL); Set selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); //decide whether add order by expression field to select - String correctorAdditionalInfo = getAdditionalInfo(); + Environment environment = ContextUtils.getBean(Environment.class); + String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION); if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) { needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL)); } @@ -67,7 +70,7 @@ public class SelectCorrector extends BaseSemanticCorrector { } private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo, - String correctS2SQL) { + String correctS2SQL) { //If it is in DETAIL mode and select *, add default metrics and dimensions. boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL); if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) { @@ -96,15 +99,4 @@ public class SelectCorrector extends BaseSemanticCorrector { } return correctS2SQL; } - - private String getAdditionalInfo() { - String correctorAdditionalInfo = null; - try { - Environment environment = ContextUtils.getBean(Environment.class); - correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information"); - } catch (Exception e) { - log.error("getAdditionalInfo error:{}", e); - } - return correctorAdditionalInfo; - } } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java index ec99cd2cd..51474df41 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.pojo.enums.QueryType; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo; import com.tencent.supersonic.headless.api.pojo.QueryConfig; @@ -10,12 +11,18 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; import com.tencent.supersonic.headless.chat.QueryContext; +import org.junit.Assert; +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.springframework.core.env.Environment; + import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; -import org.junit.Assert; -import org.junit.jupiter.api.Test; + +import static org.mockito.Mockito.when; class SelectCorrectorTest { @@ -23,6 +30,11 @@ class SelectCorrectorTest { @Test void testDoCorrect() { + MockedStatic mocked = Mockito.mockStatic(ContextUtils.class); + Environment mockEnvironment = Mockito.mock(Environment.class); + mocked.when(() -> ContextUtils.getBean(Environment.class)).thenReturn(mockEnvironment); + when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn(""); + BaseSemanticCorrector corrector = new SelectCorrector(); QueryContext queryContext = buildQueryContext(dataSetId); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();