(improvement)(test) Fix NullPointerException in SelectCorrectorTest unit test (#1260)

This commit is contained in:
lexluo09
2024-06-27 23:02:59 +08:00
committed by GitHub
parent 2fd22e0769
commit 28f95ddf9e
3 changed files with 31 additions and 21 deletions

View File

@@ -116,6 +116,12 @@
<artifactId>mockito-core</artifactId> <artifactId>mockito-core</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito-inline.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@@ -11,17 +11,17 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.QueryContext;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/** /**
* Perform SQL corrections on the "Select" section in S2SQL. * Perform SQL corrections on the "Select" section in S2SQL.
@@ -29,6 +29,8 @@ import java.util.List;
@Slf4j @Slf4j
public class SelectCorrector extends BaseSemanticCorrector { public class SelectCorrector extends BaseSemanticCorrector {
public static final String ADDITIONAL_INFORMATION = "s2.corrector.additional.information";
@Override @Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
@@ -46,14 +48,15 @@ public class SelectCorrector extends BaseSemanticCorrector {
} }
protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo, protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
String correctS2SQL) { String correctS2SQL) {
correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL); correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL);
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
//decide whether add order by expression field to select //decide whether add order by expression field to select
String correctorAdditionalInfo = getAdditionalInfo(); Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION);
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) { if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL)); needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
} }
@@ -67,7 +70,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
} }
private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo, private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
String correctS2SQL) { String correctS2SQL) {
//If it is in DETAIL mode and select *, add default metrics and dimensions. //If it is in DETAIL mode and select *, add default metrics and dimensions.
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL); boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) { if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
@@ -96,15 +99,4 @@ public class SelectCorrector extends BaseSemanticCorrector {
} }
return correctS2SQL; return correctS2SQL;
} }
private String getAdditionalInfo() {
String correctorAdditionalInfo = null;
try {
Environment environment = ContextUtils.getBean(Environment.class);
correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
} catch (Exception e) {
log.error("getAdditionalInfo error:{}", e);
}
return correctorAdditionalInfo;
}
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo; import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.QueryConfig;
@@ -10,12 +11,18 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.QueryContext;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.springframework.core.env.Environment;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import org.junit.Assert;
import org.junit.jupiter.api.Test; import static org.mockito.Mockito.when;
class SelectCorrectorTest { class SelectCorrectorTest {
@@ -23,6 +30,11 @@ class SelectCorrectorTest {
@Test @Test
void testDoCorrect() { void testDoCorrect() {
MockedStatic<ContextUtils> mocked = Mockito.mockStatic(ContextUtils.class);
Environment mockEnvironment = Mockito.mock(Environment.class);
mocked.when(() -> ContextUtils.getBean(Environment.class)).thenReturn(mockEnvironment);
when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn("");
BaseSemanticCorrector corrector = new SelectCorrector(); BaseSemanticCorrector corrector = new SelectCorrector();
QueryContext queryContext = buildQueryContext(dataSetId); QueryContext queryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();