(improvement)(chat) Decide whether to add or remove dates based on whether the dataset has partition dates. (#1512)

This commit is contained in:
lexluo09
2024-08-04 17:39:23 +08:00
committed by GitHub
parent 97bf8049d7
commit e2e45a40ab
5 changed files with 118 additions and 137 deletions

View File

@@ -122,4 +122,8 @@ public class DataSetSchema {
return new ArrayList<>();
}
public boolean containsPartitionDimensions() {
return dimensions.stream().anyMatch(SchemaElement::containsPartitionTime);
}
}

View File

@@ -1,16 +1,18 @@
package com.tencent.supersonic.headless.api.pojo;
import com.google.common.base.Objects;
import com.tencent.supersonic.common.pojo.DimensionConstants;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections4.MapUtils;
@Data
@Getter
@@ -59,4 +61,12 @@ public class SchemaElement implements Serializable {
return Objects.hashCode(dataSetId, id, name, bizName, type);
}
public boolean containsPartitionTime() {
if (MapUtils.isEmpty(extInfo)) {
return false;
}
DimensionType dimensionTYpe = (DimensionType) extInfo.get(DimensionConstants.DIMENSION_TYPE);
return DimensionType.isPartitionTime(dimensionTYpe);
}
}

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.enums;
public enum DimensionType {
categorical,
@@ -8,12 +7,19 @@ public enum DimensionType {
partition_time,
identify;
public static Boolean isTimeDimension(String type) {
return time.name().equals(type) || partition_time.name().equals(type);
public static boolean isTimeDimension(String type) {
try {
return isTimeDimension(DimensionType.valueOf(type.toUpperCase()));
} catch (IllegalArgumentException e) {
return false;
}
}
public static Boolean isTimeDimension(DimensionType type) {
return time.equals(type) || partition_time.equals(type);
public static boolean isTimeDimension(DimensionType type) {
return type == time || type == partition_time;
}
public static boolean isPartitionTime(DimensionType type) {
return type == partition_time;
}
}

View File

@@ -1,31 +1,27 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* Perform SQL corrections on the time in S2SQL.
*/
@@ -34,95 +30,79 @@ public class TimeCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
addDateIfNotExist(chatQueryContext, semanticParseInfo);
removeDateIfExist(chatQueryContext, semanticParseInfo);
if (containsPartitionDimensions(chatQueryContext, semanticParseInfo)) {
addDateIfNotExist(chatQueryContext, semanticParseInfo);
} else {
removeDateIfExist(semanticParseInfo);
}
addLowerBoundDate(semanticParseInfo);
}
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
private void removeDateIfExist(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
//decide whether remove date field from where
Environment environment = ContextUtils.getBean(Environment.class);
String correctorDate = environment.getProperty("s2.corrector.date");
if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
Set<String> removeFieldNames = new HashSet<>();
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
private boolean checkIfNameInWhereFields(Set<SchemaElement> dims, List<String> whereFields) {
for (SchemaElement element : dims) {
if (whereFields.contains(element.getName())) {
return true;
}
}
return false;
Set<String> removeFieldNames = new HashSet<>();
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
//decide whether add date field to where
Environment environment = ContextUtils.getBean(Environment.class);
String correctorDate = environment.getProperty("s2.corrector.date");
if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
return;
}
Long dataSetId = semanticParseInfo.getDataSetId();
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
boolean isDateInWhere = checkIfNameInWhereFields(dataSetSchema.getDimensions(), whereFields);
if (isDateInWhere) {
return;
}
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId,
semanticParseInfo.getQueryType());
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext,
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
if (isValidDateRange(startEndDate)) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String dateChName = TimeDimensionEnum.DAY.getChName();
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
startEndDate.getLeft(), dateChName, startEndDate.getRight());
try {
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
} catch (JSQLParserException e) {
log.error("parseCondExpression:{}", e);
}
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )",
dateChName, startEndDate.getLeft(), dateChName, startEndDate.getRight());
correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
}
}
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
private boolean containsPartitionDimensions(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
return dataSetSchema.containsPartitionDimensions();
}
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (Objects.isNull(dateBoundInfo)) {
return;
}
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
if (dateBoundInfo != null
&& StringUtils.isBlank(dateBoundInfo.getLowerBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
String upperDate = dateBoundInfo.getUpperDate();
try {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
private boolean isValidDateRange(Pair<String, String> startEndDate) {
return StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight());
}
private String addConditionToSQL(String sql, String condition) {
try {
Expression expression = CCJSqlParserUtil.parseCondExpression(condition);
return SqlAddHelper.addWhere(sql, expression);
} catch (JSQLParserException e) {
log.error("addConditionToSQL:{}", e);
return sql;
}
}
}

View File

@@ -4,12 +4,16 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
@@ -17,11 +21,6 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
* Perform SQL corrections on the "Where" section in S2SQL.
*/
@@ -30,27 +29,23 @@ public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
addQueryFilter(chatQueryContext, semanticParseInfo);
updateFieldValueByTechName(chatQueryContext, semanticParseInfo);
}
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
Expression expression = null;
try {
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
@@ -69,49 +64,35 @@ public class WhereCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(dimensions)) {
return;
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
String correctedS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
String replaceSql = SqlReplaceHelper.replaceValue(correctedS2SQL, aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
if (CollectionUtils.isEmpty(dimensions)) {
return new HashMap<>();
}
Map<String, Map<String, String>> result = new HashMap<>();
for (SchemaElement dimension : dimensions) {
if (Objects.isNull(dimension)
|| StringUtils.isEmpty(dimension.getName())
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
continue;
}
String name = dimension.getName();
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
if (Objects.isNull(valueMap) || StringUtils.isEmpty(valueMap.getTechName())) {
continue;
}
if (StringUtils.isNotEmpty(valueMap.getBizName())) {
aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
}
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
valueMap.getAlias().stream().forEach(alias -> {
if (StringUtils.isNotEmpty(alias)) {
aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
}
});
}
}
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
result.put(name, aliasAndBizNameToTechName);
}
}
return result;
return dimensions.stream()
.filter(dimension -> Objects.nonNull(dimension)
&& StringUtils.isNotEmpty(dimension.getName())
&& !CollectionUtils.isEmpty(dimension.getSchemaValueMaps()))
.collect(Collectors.toMap(
SchemaElement::getName,
dimension -> dimension.getSchemaValueMaps().stream()
.filter(valueMap -> Objects.nonNull(valueMap)
&& StringUtils.isNotEmpty(valueMap.getTechName()))
.flatMap(valueMap -> {
Map<String, String> map = new HashMap<>();
if (StringUtils.isNotEmpty(valueMap.getBizName())) {
map.put(valueMap.getBizName(), valueMap.getTechName());
}
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
valueMap.getAlias().stream()
.filter(StringUtils::isNotEmpty)
.forEach(alias -> map.put(alias, valueMap.getTechName()));
}
return map.entrySet().stream();
})
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
));
}
}