mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
[improvement][Chat] Add TimeCorrector and rename the associated SqlParserHelper. (#707)
This commit is contained in:
@@ -6,9 +6,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -75,14 +75,14 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
|
||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||
|
||||
// If there is no aggregate function in the S2SQL statement and
|
||||
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||
.collect(Collectors.toSet());
|
||||
@@ -94,7 +94,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
|
||||
needAddFields.removeAll(selectFields);
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
// check if has distinct
|
||||
boolean hasDistinct = SqlParserSelectHelper.hasDistinct(correctS2SQL);
|
||||
boolean hasDistinct = SqlSelectHelper.hasDistinct(correctS2SQL);
|
||||
if (hasDistinct) {
|
||||
log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
|
||||
return;
|
||||
@@ -54,7 +54,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
@@ -63,12 +63,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
return;
|
||||
}
|
||||
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
.filter(field -> {
|
||||
@@ -78,13 +78,13 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
|
||||
addAggregate(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
|
||||
@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.core.corrector;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -42,18 +42,18 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
||||
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
|
||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -19,7 +19,7 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2QL.
|
||||
* Perform schema corrections on the Schema information in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
@@ -41,20 +41,20 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
)));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -16,8 +16,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||
&& !CollectionUtils.isEmpty(selectFields)
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the time in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class TimeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
addLowerBoundDate(semanticParseInfo);
|
||||
|
||||
}
|
||||
|
||||
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||
if (Objects.isNull(dateBoundInfo)) {
|
||||
return;
|
||||
}
|
||||
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
|
||||
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
|
||||
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
|
||||
String upperDate = dateBoundInfo.getUpperDate();
|
||||
try {
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -10,9 +10,9 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -38,8 +38,6 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
addDateIfNotExist(queryContext, semanticParseInfo);
|
||||
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
addQueryFilter(queryContext, semanticParseInfo);
|
||||
|
||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||
@@ -58,25 +56,19 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getViewId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlAddHelper.addWhere(
|
||||
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
}
|
||||
}
|
||||
@@ -107,7 +99,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -54,7 +54,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -72,7 +72,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlEqualHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -52,7 +52,7 @@ public class LLMResponseService {
|
||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
result.put(key, entry.getValue());
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class TimeCorrectorTest {
|
||||
|
||||
@Test
|
||||
void testDoCorrect() {
|
||||
TimeCorrector corrector = new TimeCorrector();
|
||||
QueryContext queryContext = new QueryContext();
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
//1.数据日期 <=
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//2.数据日期 <
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//3.数据日期 >=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//4.数据日期 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//5.no 数据日期
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//6. 数据日期-月 <=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
|
||||
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//7. 数据日期-月 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
|
||||
//8. no where
|
||||
sql = "SELECT COUNT(1) FROM 数据库";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
|
||||
}
|
||||
}
|
||||
@@ -15,8 +15,8 @@ import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@@ -61,9 +61,9 @@ public class MetricCheckProcessor implements ParseResultProcessor {
|
||||
|
||||
public String processCorrectSql(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
|
||||
String correctSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql);
|
||||
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql);
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctSql);
|
||||
List<String> groupByFields = SqlSelectHelper.getGroupByFields(correctSql);
|
||||
List<String> metricFields = SqlSelectHelper.getAggregateFields(correctSql);
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctSql);
|
||||
List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);
|
||||
if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) {
|
||||
return correctSql;
|
||||
@@ -195,8 +195,8 @@ public class MetricCheckProcessor implements ParseResultProcessor {
|
||||
}
|
||||
|
||||
private boolean checkHasMetric(String correctSql, SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctSql);
|
||||
List<String> aggFields = SqlParserSelectHelper.getAggregateFields(correctSql);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctSql);
|
||||
List<String> aggFields = SqlSelectHelper.getAggregateFields(correctSql);
|
||||
List<String> collect = semanticSchema.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
for (String field : selectFields) {
|
||||
@@ -209,11 +209,11 @@ public class MetricCheckProcessor implements ParseResultProcessor {
|
||||
|
||||
private static String removeFieldInSql(String sql, Set<String> metricToRemove,
|
||||
Set<String> dimensionByToRemove, Set<String> whereFieldsToRemove) {
|
||||
sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);
|
||||
sql = SqlParserRemoveHelper.removeSelect(sql, metricToRemove);
|
||||
sql = SqlParserRemoveHelper.removeSelect(sql, dimensionByToRemove);
|
||||
sql = SqlParserRemoveHelper.removeGroupBy(sql, dimensionByToRemove);
|
||||
sql = SqlParserRemoveHelper.removeNumberFilter(sql);
|
||||
sql = SqlRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);
|
||||
sql = SqlRemoveHelper.removeSelect(sql, metricToRemove);
|
||||
sql = SqlRemoveHelper.removeSelect(sql, dimensionByToRemove);
|
||||
sql = SqlRemoveHelper.removeGroupBy(sql, dimensionByToRemove);
|
||||
sql = SqlRemoveHelper.removeNumberFilter(sql);
|
||||
return sql;
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -60,7 +60,7 @@ public class ParseInfoProcessor implements ParseResultProcessor {
|
||||
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
|
||||
return;
|
||||
}
|
||||
List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
|
||||
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
|
||||
//set dataInfo
|
||||
try {
|
||||
if (!org.apache.commons.collections.CollectionUtils.isEmpty(expressions)) {
|
||||
@@ -87,15 +87,15 @@ public class ParseInfoProcessor implements ParseResultProcessor {
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
Set<SchemaElement> metrics = getElements(viewId, allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(getElements(viewId, groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(getElements(viewId, selectDimensions, semanticSchema.getDimensions()));
|
||||
}
|
||||
|
||||
@@ -60,10 +60,10 @@ import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -318,7 +318,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (Objects.nonNull(parseInfo.getSqlInfo())
|
||||
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
fields = SqlParserSelectHelper.getAllFields(correctorSql);
|
||||
fields = SqlSelectHelper.getAllFields(correctorSql);
|
||||
}
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
|
||||
&& checkMetricReplace(fields, queryData.getMetrics())) {
|
||||
@@ -373,8 +373,8 @@ public class QueryServiceImpl implements QueryService {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FieldExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FieldExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<Expression> addWhereConditions = new ArrayList<>();
|
||||
List<Expression> addHavingConditions = new ArrayList<>();
|
||||
Set<String> removeWhereFieldNames = new HashSet<>();
|
||||
@@ -384,16 +384,16 @@ public class QueryServiceImpl implements QueryService {
|
||||
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
|
||||
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
|
||||
whereExpressionList, addWhereConditions, removeWhereFieldNames);
|
||||
correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
|
||||
correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
|
||||
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
|
||||
// replace having filter
|
||||
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
|
||||
correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
|
||||
correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
|
||||
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
|
||||
|
||||
correctorSql = SqlParserAddHelper.addWhere(correctorSql, addWhereConditions);
|
||||
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
|
||||
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||
log.info("correctorSql after replacing:{}", correctorSql);
|
||||
return correctorSql;
|
||||
}
|
||||
@@ -407,7 +407,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||
if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
|
||||
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
|
||||
correctorSql = SqlParserReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
}
|
||||
log.info("after replaceMetrics:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||
|
||||
Reference in New Issue
Block a user