[improvement][Chat] Add TimeCorrector and rename the associated SqlParserHelper. (#707)

This commit is contained in:
lexluo09
2024-02-01 15:29:07 +08:00
committed by GitHub
parent 2c1c443b3e
commit 491c76368c
43 changed files with 749 additions and 486 deletions

View File

@@ -6,9 +6,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -75,14 +75,14 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
}
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
// If there is no aggregate function in the S2SQL statement and
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
.collect(Collectors.toSet());
@@ -94,7 +94,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
}
needAddFields.removeAll(selectFields);
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
}
@@ -123,7 +123,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
}
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}

View File

@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
// check if has distinct
boolean hasDistinct = SqlParserSelectHelper.hasDistinct(correctS2SQL);
boolean hasDistinct = SqlSelectHelper.hasDistinct(correctS2SQL);
if (hasDistinct) {
log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
return;
@@ -54,7 +54,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return;
@@ -63,12 +63,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
return;
}
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
return;
}
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
Set<String> groupByFields = selectFields.stream()
.filter(field -> dimensions.contains(field))
.filter(field -> {
@@ -78,13 +78,13 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true;
})
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
addAggregate(queryContext, semanticParseInfo);
}
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;

View File

@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@@ -42,18 +42,18 @@ public class HavingCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(metrics)) {
return;
}
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
}
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
return;
}
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
if (!CollectionUtils.isEmpty(havingExpressionList)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -19,7 +19,7 @@ import java.util.Set;
import java.util.stream.Collectors;
/**
* Perform schema corrections on the Schema information in S2QL.
* Perform schema corrections on the Schema information in S2SQL.
*/
@Slf4j
public class SchemaCorrector extends BaseSemanticCorrector {
@@ -41,20 +41,20 @@ public class SchemaCorrector extends BaseSemanticCorrector {
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
sqlInfo.setCorrectS2SQL(sql);
}
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
sqlInfo.setCorrectS2SQL(replaceAlias);
}
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);
}
@@ -70,7 +70,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectS2SQL(sql);
}
@@ -102,7 +102,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
)));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql);
}
}

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -16,8 +16,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
if (!CollectionUtils.isEmpty(aggregateFields)
&& !CollectionUtils.isEmpty(selectFields)

View File

@@ -0,0 +1,57 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
/**
* Perform SQL corrections on the time in S2SQL.
*/
@Slf4j
public class TimeCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
parserDateDiffFunction(semanticParseInfo);
addLowerBoundDate(semanticParseInfo);
}
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (Objects.isNull(dateBoundInfo)) {
return;
}
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
String upperDate = dateBoundInfo.getUpperDate();
try {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}

View File

@@ -10,9 +10,9 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
@@ -38,8 +38,6 @@ public class WhereCorrector extends BaseSemanticCorrector {
addDateIfNotExist(queryContext, semanticParseInfo);
parserDateDiffFunction(semanticParseInfo);
addQueryFilter(queryContext, semanticParseInfo);
updateFieldValueByTechName(queryContext, semanticParseInfo);
@@ -58,25 +56,19 @@ public class WhereCorrector extends BaseSemanticCorrector {
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getViewId());
if (StringUtils.isNotBlank(currentDate)) {
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlParserAddHelper.addWhere(
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlAddHelper.addWhere(
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
}
}
@@ -107,7 +99,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -54,7 +54,7 @@ public class QueryTypeParser implements SemanticParser {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
//If all the fields in the SELECT statement are of tag type.
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
@@ -72,7 +72,7 @@ public class QueryTypeParser implements SemanticParser {
}
}
//2. metric queryType
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlEqualHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.springframework.stereotype.Service;
@@ -52,7 +52,7 @@ public class LLMResponseService {
Map<String, LLMSqlResp> result = new HashMap<>();
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
String key = entry.getKey();
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
continue;
}
result.put(key, entry.getValue());

View File

@@ -0,0 +1,100 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class TimeCorrectorTest {
@Test
void testDoCorrect() {
TimeCorrector corrector = new TimeCorrector();
QueryContext queryContext = new QueryContext();
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
//1.数据日期 <=
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//2.数据日期 <
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//3.数据日期 >=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//4.数据日期 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//5.no 数据日期
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//6. 数据日期-月 <=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//7. 数据日期-月 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
//8. no where
sql = "SELECT COUNT(1) FROM 数据库";
sqlInfo.setCorrectS2SQL(sql);
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
}
}