(improvement)(chat) Decide whether to add or remove dates based on whether the dataset has partition dates. (#1512)

This commit is contained in:
lexluo09
2024-08-04 17:39:23 +08:00
committed by GitHub
parent 97bf8049d7
commit e2e45a40ab
5 changed files with 118 additions and 137 deletions

View File

@@ -122,4 +122,8 @@ public class DataSetSchema {
return new ArrayList<>(); return new ArrayList<>();
} }
/**
 * Tells whether at least one dimension of this data set reports a partition time.
 *
 * @return {@code true} as soon as one element of {@code dimensions} answers
 *         {@code true} to {@link SchemaElement#containsPartitionTime()}, otherwise {@code false}
 */
public boolean containsPartitionDimensions() {
    for (SchemaElement dimension : dimensions) {
        if (dimension.containsPartitionTime()) {
            return true;
        }
    }
    return false;
}
} }

View File

@@ -1,16 +1,18 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.tencent.supersonic.common.pojo.DimensionConstants;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.apache.commons.collections4.MapUtils;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Data @Data
@Getter @Getter
@@ -59,4 +61,12 @@ public class SchemaElement implements Serializable {
return Objects.hashCode(dataSetId, id, name, bizName, type); return Objects.hashCode(dataSetId, id, name, bizName, type);
} }
/**
 * Tells whether this element is flagged as a partition-time dimension in {@code extInfo}.
 *
 * <p>The value stored under {@link DimensionConstants#DIMENSION_TYPE} is accepted either as a
 * {@link DimensionType} or as the enum constant's name in a {@code String}. The original
 * unconditional cast threw {@code ClassCastException} for anything that was not already a
 * {@code DimensionType}.
 * NOTE(review): a String value presumably appears after the schema has been serialized and
 * read back - confirm against the code that populates {@code extInfo}.
 *
 * @return {@code true} only when the stored type resolves to {@code partition_time};
 *         {@code false} when {@code extInfo} is empty, the entry is missing, or the value
 *         cannot be resolved to a {@code DimensionType}
 */
public boolean containsPartitionTime() {
    if (MapUtils.isEmpty(extInfo)) {
        return false;
    }
    Object rawType = extInfo.get(DimensionConstants.DIMENSION_TYPE);
    if (rawType instanceof DimensionType) {
        return DimensionType.isPartitionTime((DimensionType) rawType);
    }
    if (rawType instanceof String) {
        try {
            return DimensionType.isPartitionTime(DimensionType.valueOf((String) rawType));
        } catch (IllegalArgumentException ignored) {
            // Not the name of any DimensionType constant, so it cannot be a partition time.
            return false;
        }
    }
    return false;
}
} }

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.enums; package com.tencent.supersonic.headless.api.pojo.enums;
public enum DimensionType { public enum DimensionType {
categorical, categorical,
@@ -8,12 +7,19 @@ public enum DimensionType {
partition_time, partition_time,
identify; identify;
public static Boolean isTimeDimension(String type) { public static boolean isTimeDimension(String type) {
return time.name().equals(type) || partition_time.name().equals(type); try {
return isTimeDimension(DimensionType.valueOf(type.toUpperCase()));
} catch (IllegalArgumentException e) {
return false;
}
} }
public static Boolean isTimeDimension(DimensionType type) { public static boolean isTimeDimension(DimensionType type) {
return time.equals(type) || partition_time.equals(type); return type == time || type == partition_time;
} }
/**
 * Tells whether the given type is the {@code partition_time} constant.
 *
 * @param type dimension type to check, may be {@code null}
 * @return {@code true} only for {@code partition_time}; {@code false} otherwise,
 *         including for {@code null}
 */
public static boolean isPartitionTime(DimensionType type) {
    return partition_time == type;
}
} }

View File

@@ -1,31 +1,27 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/** /**
* Perform SQL corrections on the time in S2SQL. * Perform SQL corrections on the time in S2SQL.
*/ */
@@ -34,95 +30,79 @@ public class TimeCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
if (containsPartitionDimensions(chatQueryContext, semanticParseInfo)) {
addDateIfNotExist(chatQueryContext, semanticParseInfo); addDateIfNotExist(chatQueryContext, semanticParseInfo);
} else {
removeDateIfExist(chatQueryContext, semanticParseInfo); removeDateIfExist(semanticParseInfo);
}
addLowerBoundDate(semanticParseInfo); addLowerBoundDate(semanticParseInfo);
} }
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void removeDateIfExist(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
//decide whether remove date field from where Set<String> removeFieldNames = new HashSet<>();
Environment environment = ContextUtils.getBean(Environment.class); removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
String correctorDate = environment.getProperty("s2.corrector.date"); removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) { removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
Set<String> removeFieldNames = new HashSet<>(); correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
private boolean checkIfNameInWhereFields(Set<SchemaElement> dims, List<String> whereFields) {
for (SchemaElement element : dims) {
if (whereFields.contains(element.getName())) {
return true;
}
}
return false;
} }
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
//decide whether add date field to where
Environment environment = ContextUtils.getBean(Environment.class);
String correctorDate = environment.getProperty("s2.corrector.date");
if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
return;
}
Long dataSetId = semanticParseInfo.getDataSetId(); Long dataSetId = semanticParseInfo.getDataSetId();
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
boolean isDateInWhere = checkIfNameInWhereFields(dataSetSchema.getDimensions(), whereFields);
if (isDateInWhere) {
return;
}
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) { if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId,
semanticParseInfo.getQueryType());
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, if (isValidDateRange(startEndDate)) {
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String dateChName = TimeDimensionEnum.DAY.getChName(); String dateChName = TimeDimensionEnum.DAY.getChName();
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName, String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )",
startEndDate.getLeft(), dateChName, startEndDate.getRight()); dateChName, startEndDate.getLeft(), dateChName, startEndDate.getRight());
try { correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
} catch (JSQLParserException e) {
log.error("parseCondExpression:{}", e);
}
} }
} }
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
/**
 * Looks up the schema of the parsed data set and asks it whether it has partition dimensions.
 *
 * @param chatQueryContext  query context providing the semantic schema
 * @param semanticParseInfo parse result providing the data set id
 * @return the answer of {@code DataSetSchema.containsPartitionDimensions()} for that data set
 */
private boolean containsPartitionDimensions(ChatQueryContext chatQueryContext,
        SemanticParseInfo semanticParseInfo) {
    Long targetDataSetId = semanticParseInfo.getDataSetId();
    // NOTE(review): no schema registered under this id would make the next call fail - confirm
    // that the schema map always contains the parsed data set.
    DataSetSchema schemaOfDataSet =
            chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(targetDataSetId);
    return schemaOfDataSet.containsPartitionDimensions();
}
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) { private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (Objects.isNull(dateBoundInfo)) {
return; if (dateBoundInfo != null
} && StringUtils.isBlank(dateBoundInfo.getLowerBound())
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound()) && StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) { && StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
String upperDate = dateBoundInfo.getUpperDate(); String upperDate = dateBoundInfo.getUpperDate();
try { String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
} }
/**
 * Tells whether both ends of a start/end date pair are usable.
 *
 * @param startEndDate pair of start date (left) and end date (right), may be {@code null}
 * @return {@code true} when the pair is non-null and both values are non-blank
 */
private boolean isValidDateRange(Pair<String, String> startEndDate) {
    // A missing pair is treated as "no range" instead of failing with a NullPointerException.
    return startEndDate != null
            && StringUtils.isNotBlank(startEndDate.getLeft())
            && StringUtils.isNotBlank(startEndDate.getRight());
}
/**
 * Parses a condition expression and adds it to the WHERE clause of the given SQL.
 *
 * @param sql       SQL statement to extend
 * @param condition textual condition expression to parse and add
 * @return the SQL returned by {@code SqlAddHelper.addWhere}, or {@code sql} unchanged when
 *         the condition cannot be parsed (the failure is logged, not rethrown)
 */
private String addConditionToSQL(String sql, String condition) {
    try {
        Expression expression = CCJSqlParserUtil.parseCondExpression(condition);
        return SqlAddHelper.addWhere(sql, expression);
    } catch (JSQLParserException e) {
        // The exception goes last so SLF4J logs its stack trace; the placeholder is filled
        // with the condition that failed to parse.
        log.error("addConditionToSQL failed, condition:{}", condition, e);
        return sql;
    }
}
} }

View File

@@ -4,12 +4,16 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser; import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
@@ -17,11 +21,6 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/** /**
* Perform SQL corrections on the "Where" section in S2SQL. * Perform SQL corrections on the "Where" section in S2SQL.
*/ */
@@ -30,27 +29,23 @@ public class WhereCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
addQueryFilter(chatQueryContext, semanticParseInfo); addQueryFilter(chatQueryContext, semanticParseInfo);
updateFieldValueByTechName(chatQueryContext, semanticParseInfo); updateFieldValueByTechName(chatQueryContext, semanticParseInfo);
} }
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters()); String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) { if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to correctS2SQL :{}", queryFilter); log.info("add queryFilter to correctS2SQL :{}", queryFilter);
Expression expression = null;
try { try {
expression = CCJSqlParserUtil.parseCondExpression(queryFilter); Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} catch (JSQLParserException e) { } catch (JSQLParserException e) {
log.error("parseCondExpression", e); log.error("parseCondExpression", e);
} }
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
} }
@@ -69,49 +64,35 @@ public class WhereCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(dimensions)) { if (CollectionUtils.isEmpty(dimensions)) {
return; return;
} }
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), String correctedS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
aliasAndBizNameToTechName); String replaceSql = SqlReplaceHelper.replaceValue(correctedS2SQL, aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
} }
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) { private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
if (CollectionUtils.isEmpty(dimensions)) { return dimensions.stream()
return new HashMap<>(); .filter(dimension -> Objects.nonNull(dimension)
} && StringUtils.isNotEmpty(dimension.getName())
&& !CollectionUtils.isEmpty(dimension.getSchemaValueMaps()))
Map<String, Map<String, String>> result = new HashMap<>(); .collect(Collectors.toMap(
SchemaElement::getName,
for (SchemaElement dimension : dimensions) { dimension -> dimension.getSchemaValueMaps().stream()
if (Objects.isNull(dimension) .filter(valueMap -> Objects.nonNull(valueMap)
|| StringUtils.isEmpty(dimension.getName()) && StringUtils.isNotEmpty(valueMap.getTechName()))
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) { .flatMap(valueMap -> {
continue; Map<String, String> map = new HashMap<>();
} if (StringUtils.isNotEmpty(valueMap.getBizName())) {
String name = dimension.getName(); map.put(valueMap.getBizName(), valueMap.getTechName());
}
Map<String, String> aliasAndBizNameToTechName = new HashMap<>(); if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
valueMap.getAlias().stream()
for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) { .filter(StringUtils::isNotEmpty)
if (Objects.isNull(valueMap) || StringUtils.isEmpty(valueMap.getTechName())) { .forEach(alias -> map.put(alias, valueMap.getTechName()));
continue; }
} return map.entrySet().stream();
if (StringUtils.isNotEmpty(valueMap.getBizName())) { })
aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName()); .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
} ));
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
valueMap.getAlias().stream().forEach(alias -> {
if (StringUtils.isNotEmpty(alias)) {
aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
}
});
}
}
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
result.put(name, aliasAndBizNameToTechName);
}
}
return result;
} }
} }