mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
[improvement][Chat] Add TimeCorrector and rename the associated SqlParserHelper. (#707)
This commit is contained in:
@@ -6,9 +6,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -75,14 +75,14 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
|
||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||
|
||||
// If there is no aggregate function in the S2SQL statement and
|
||||
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||
.collect(Collectors.toSet());
|
||||
@@ -94,7 +94,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
|
||||
needAddFields.removeAll(selectFields);
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
// check if has distinct
|
||||
boolean hasDistinct = SqlParserSelectHelper.hasDistinct(correctS2SQL);
|
||||
boolean hasDistinct = SqlSelectHelper.hasDistinct(correctS2SQL);
|
||||
if (hasDistinct) {
|
||||
log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
|
||||
return;
|
||||
@@ -54,7 +54,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
@@ -63,12 +63,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
return;
|
||||
}
|
||||
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
.filter(field -> {
|
||||
@@ -78,13 +78,13 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
|
||||
addAggregate(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
|
||||
@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.core.corrector;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -42,18 +42,18 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
||||
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
|
||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -19,7 +19,7 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2QL.
|
||||
* Perform schema corrections on the Schema information in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
@@ -41,20 +41,20 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
)));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -16,8 +16,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||
&& !CollectionUtils.isEmpty(selectFields)
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the time in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class TimeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
addLowerBoundDate(semanticParseInfo);
|
||||
|
||||
}
|
||||
|
||||
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||
if (Objects.isNull(dateBoundInfo)) {
|
||||
return;
|
||||
}
|
||||
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
|
||||
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
|
||||
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
|
||||
String upperDate = dateBoundInfo.getUpperDate();
|
||||
try {
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -10,9 +10,9 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -38,8 +38,6 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
addDateIfNotExist(queryContext, semanticParseInfo);
|
||||
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
addQueryFilter(queryContext, semanticParseInfo);
|
||||
|
||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||
@@ -58,25 +56,19 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getViewId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlAddHelper.addWhere(
|
||||
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
}
|
||||
}
|
||||
@@ -107,7 +99,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -54,7 +54,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -72,7 +72,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlEqualHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -52,7 +52,7 @@ public class LLMResponseService {
|
||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
result.put(key, entry.getValue());
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class TimeCorrectorTest {
|
||||
|
||||
@Test
|
||||
void testDoCorrect() {
|
||||
TimeCorrector corrector = new TimeCorrector();
|
||||
QueryContext queryContext = new QueryContext();
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
//1.数据日期 <=
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//2.数据日期 <
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//3.数据日期 >=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//4.数据日期 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//5.no 数据日期
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//6. 数据日期-月 <=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
|
||||
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//7. 数据日期-月 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//8. no where
|
||||
sql = "SELECT COUNT(1) FROM 数据库";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user