From f198ce1ef8e9e1dd36c88288b88fa05679b1025a Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sat, 18 Nov 2023 10:28:09 +0800 Subject: [PATCH] [improvement](chat) If there is no aggregate function in the S2SQL, add the field to the 'SELECT' statement. (#401) --- .../chat/corrector/BaseSemanticCorrector.java | 16 +++++++++++++--- .../common/pojo/enums/TimeDimensionEnum.java | 4 ++++ .../service/impl/SysParameterServiceImpl.java | 4 +++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index 439e695a8..ff034d8d9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.ArrayList; @@ -23,8 +24,8 @@ import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; /** - * basic semantic correction functionality, offering common methods and an - * abstract method called doCorrect + * basic semantic correction functionality, offering common methods and an + * abstract method called doCorrect */ @Slf4j public abstract class BaseSemanticCorrector implements SemanticCorrector { @@ -80,12 +81,21 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { Set needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL)); needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL)); + // If there is no aggregate function in the S2SQL statement and + // there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement. + if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) { + List whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL); + List timeChNameList = TimeDimensionEnum.getChNameList(); + Set timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field)) + .collect(Collectors.toSet()); + needAddFields.addAll(timeFields); + } + if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) { return; } needAddFields.removeAll(selectFields); - needAddFields.remove(TimeDimensionEnum.DAY.getChName()); String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields); } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java index 2ab03c342..d27357993 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java @@ -27,6 +27,10 @@ public enum TimeDimensionEnum { return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList()); } + public static List getChNameList() { + return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName).collect(Collectors.toList()); + } + public String getName() { return name; } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java index dd87a0276..370ee01ae 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java @@ -39,7 +39,9 @@ public class SysParameterServiceImpl private SysParameter convert(SysParameterDO sysParameterDO) { SysParameter sysParameter = new SysParameter(); sysParameter.setId(sysParameterDO.getId()); - List parameters = JsonUtil.toObject(sysParameterDO.getParameters(), new TypeReference>() {}); + List parameters = JsonUtil.toObject(sysParameterDO.getParameters(), + new TypeReference>() { + }); sysParameter.setParameters(parameters); sysParameter.setAdminList(sysParameterDO.getAdmin()); return sysParameter;