diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index 8ff0a620e..2e2bc9b35 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -122,4 +122,8 @@ public class DataSetSchema { return new ArrayList<>(); } + public boolean containsPartitionDimensions() { + return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime); + } + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index 8022c0fa2..7beb729af 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -1,16 +1,18 @@ package com.tencent.supersonic.headless.api.pojo; import com.google.common.base.Objects; +import com.tencent.supersonic.common.pojo.DimensionConstants; +import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; +import java.io.Serializable; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; - -import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import org.apache.commons.collections4.MapUtils; @Data @Getter @@ -59,4 +61,12 @@ public class SchemaElement implements Serializable { return Objects.hashCode(dataSetId, id, name, bizName, type); } + public boolean containsPartitionTime() { + if (MapUtils.isEmpty(extInfo)) { + return false; + } + DimensionType dimensionTYpe = (DimensionType) extInfo.get(DimensionConstants.DIMENSION_TYPE); + return DimensionType.isPartitionTime(dimensionTYpe); + } + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java index 355e83536..b1e98960a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; - public enum DimensionType { categorical, @@ -8,12 +7,19 @@ public enum DimensionType { partition_time, identify; - public static Boolean isTimeDimension(String type) { - return time.name().equals(type) || partition_time.name().equals(type); + public static boolean isTimeDimension(String type) { + try { + return isTimeDimension(DimensionType.valueOf(type.toUpperCase())); + } catch (IllegalArgumentException e) { + return false; + } } - public static Boolean isTimeDimension(DimensionType type) { - return time.equals(type) || partition_time.equals(type); + public static boolean isTimeDimension(DimensionType type) { + return type == time || type == partition_time; } + public static boolean isPartitionTime(DimensionType type) { + return type == partition_time; + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index b54fcc2ab..444554c84 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -1,31 +1,27 @@ package com.tencent.supersonic.headless.chat.corrector; -import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; -import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper; -import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; -import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; +import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatQueryContext; +import java.util.HashSet; +import java.util.List; +import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; -import org.springframework.core.env.Environment; import org.springframework.util.CollectionUtils; -import java.util.HashSet; -import java.util.List; -import java.util.Objects; -import java.util.Set; - /** * Perform SQL corrections on the time in S2SQL. */ @@ -34,95 +30,79 @@ public class TimeCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - - addDateIfNotExist(chatQueryContext, semanticParseInfo); - - removeDateIfExist(chatQueryContext, semanticParseInfo); - + if (containsPartitionDimensions(chatQueryContext, semanticParseInfo)) { + addDateIfNotExist(chatQueryContext, semanticParseInfo); + } else { + removeDateIfExist(semanticParseInfo); + } addLowerBoundDate(semanticParseInfo); - } - private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + private void removeDateIfExist(SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); - //decide whether remove date field from where - Environment environment = ContextUtils.getBean(Environment.class); - String correctorDate = environment.getProperty("s2.corrector.date"); - if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) { - Set removeFieldNames = new HashSet<>(); - removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); - removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); - removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); - correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); - semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); - } - } - - private boolean checkIfNameInWhereFields(Set dims, List whereFields) { - for (SchemaElement element : dims) { - if (whereFields.contains(element.getName())) { - return true; - } - } - return false; + Set removeFieldNames = new HashSet<>(); + removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); + removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); + removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); + correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); - - //decide whether add date field to where - Environment environment = ContextUtils.getBean(Environment.class); - String correctorDate = environment.getProperty("s2.corrector.date"); - if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) { - return; - } Long dataSetId = semanticParseInfo.getDataSetId(); - DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); - boolean isDateInWhere = checkIfNameInWhereFields(dataSetSchema.getDimensions(), whereFields); - if (isDateInWhere) { - return; - } + if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) { + Pair startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, + semanticParseInfo.getQueryType()); - Pair startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, - semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType()); - - if (StringUtils.isNotBlank(startEndDate.getLeft()) - && StringUtils.isNotBlank(startEndDate.getRight())) { + if (isValidDateRange(startEndDate)) { correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); String dateChName = TimeDimensionEnum.DAY.getChName(); - String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName, - startEndDate.getLeft(), dateChName, startEndDate.getRight()); - try { - Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr); - correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression); - } catch (JSQLParserException e) { - log.error("parseCondExpression:{}", e); - } + String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", + dateChName, startEndDate.getLeft(), dateChName, startEndDate.getRight()); + correctS2SQL = addConditionToSQL(correctS2SQL, condExpr); } } semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } + private boolean containsPartitionDimensions(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { + Long dataSetId = semanticParseInfo.getDataSetId(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); + DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); + return dataSetSchema.containsPartitionDimensions(); + } + private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL); - if (Objects.isNull(dateBoundInfo)) { - return; - } - if (StringUtils.isBlank(dateBoundInfo.getLowerBound()) + + if (dateBoundInfo != null + && StringUtils.isBlank(dateBoundInfo.getLowerBound()) && StringUtils.isNotBlank(dateBoundInfo.getUpperBound()) && StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) { String upperDate = dateBoundInfo.getUpperDate(); - try { - correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); - String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'"; - correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr)); - } catch (JSQLParserException e) { - log.error("parseCondExpression", e); - } + String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'"; + correctS2SQL = addConditionToSQL(correctS2SQL, condExpr); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } } + + private boolean isValidDateRange(Pair startEndDate) { + return StringUtils.isNotBlank(startEndDate.getLeft()) + && StringUtils.isNotBlank(startEndDate.getRight()); + } + + private String addConditionToSQL(String sql, String condition) { + try { + Expression expression = CCJSqlParserUtil.parseCondExpression(condition); + return SqlAddHelper.addWhere(sql, expression); + } catch (JSQLParserException e) { + log.error("addConditionToSQL:{}", e); + return sql; + } + } } \ No newline at end of file diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java index 4c8421598..bf645a961 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java @@ -4,12 +4,16 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.headless.api.pojo.SchemaElement; -import com.tencent.supersonic.headless.api.pojo.SchemaValueMap; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.utils.QueryFilterParser; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; @@ -17,11 +21,6 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; - /** * Perform SQL corrections on the "Where" section in S2SQL. */ @@ -30,27 +29,23 @@ public class WhereCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - addQueryFilter(chatQueryContext, semanticParseInfo); - updateFieldValueByTechName(chatQueryContext, semanticParseInfo); } protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters()); - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); if (StringUtils.isNotEmpty(queryFilter)) { log.info("add queryFilter to correctS2SQL :{}", queryFilter); - Expression expression = null; try { - expression = CCJSqlParserUtil.parseCondExpression(queryFilter); + Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); + correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } catch (JSQLParserException e) { log.error("parseCondExpression", e); } - correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression); - semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } } @@ -69,49 +64,35 @@ public class WhereCorrector extends BaseSemanticCorrector { if (CollectionUtils.isEmpty(dimensions)) { return; } - Map> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); - String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), - aliasAndBizNameToTechName); - semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); + String correctedS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); + String replaceSql = SqlReplaceHelper.replaceValue(correctedS2SQL, aliasAndBizNameToTechName); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql); } private Map> getAliasAndBizNameToTechName(List dimensions) { - if (CollectionUtils.isEmpty(dimensions)) { - return new HashMap<>(); - } - - Map> result = new HashMap<>(); - - for (SchemaElement dimension : dimensions) { - if (Objects.isNull(dimension) - || StringUtils.isEmpty(dimension.getName()) - || CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) { - continue; - } - String name = dimension.getName(); - - Map aliasAndBizNameToTechName = new HashMap<>(); - - for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) { - if (Objects.isNull(valueMap) || StringUtils.isEmpty(valueMap.getTechName())) { - continue; - } - if (StringUtils.isNotEmpty(valueMap.getBizName())) { - aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName()); - } - if (!CollectionUtils.isEmpty(valueMap.getAlias())) { - valueMap.getAlias().stream().forEach(alias -> { - if (StringUtils.isNotEmpty(alias)) { - aliasAndBizNameToTechName.put(alias, valueMap.getTechName()); - } - }); - } - } - if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) { - result.put(name, aliasAndBizNameToTechName); - } - } - return result; + return dimensions.stream() + .filter(dimension -> Objects.nonNull(dimension) + && StringUtils.isNotEmpty(dimension.getName()) + && !CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) + .collect(Collectors.toMap( + SchemaElement::getName, + dimension -> dimension.getSchemaValueMaps().stream() + .filter(valueMap -> Objects.nonNull(valueMap) + && StringUtils.isNotEmpty(valueMap.getTechName())) + .flatMap(valueMap -> { + Map map = new HashMap<>(); + if (StringUtils.isNotEmpty(valueMap.getBizName())) { + map.put(valueMap.getBizName(), valueMap.getTechName()); + } + if (!CollectionUtils.isEmpty(valueMap.getAlias())) { + valueMap.getAlias().stream() + .filter(StringUtils::isNotEmpty) + .forEach(alias -> map.put(alias, valueMap.getTechName())); + } + return map.entrySet().stream(); + }) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + )); } }