(improvement)(chat) rename logicSql to correctS2SQL and the rule does not perform corrector operation (#373)

This commit is contained in:
lexluo09
2023-11-13 15:36:27 +08:00
committed by GitHub
parent cdb84716b7
commit 0f02e21eaa
8 changed files with 56 additions and 50 deletions

View File

@@ -71,10 +71,10 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return result;
}
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String logicSql) {
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(logicSql));
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(logicSql));
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(logicSql));
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
return;
@@ -82,13 +82,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
needAddFields.removeAll(selectFields);
needAddFields.remove(TimeDimensionEnum.DAY.getChName());
String replaceFields = SqlParserAddHelper.addFieldsToSelect(logicSql, new ArrayList<>(needAddFields));
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
}
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Long modelId = semanticParseInfo.getModel().getModel();
List<SchemaElement> metrics = getMetricElements(modelId);
@@ -104,7 +104,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
}
String aggregateSql = SqlParserAddHelper.addAggregateToField(logicSql, metricToAggregate);
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}

View File

@@ -31,7 +31,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String logicSql = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
//add alias field name
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
@@ -47,7 +47,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return;
@@ -56,12 +56,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
return;
}
if (SqlParserSelectHelper.hasGroupBy(logicSql)) {
log.info("not add group by ,exist group by in logicSql:{}", logicSql);
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
return;
}
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
Set<String> groupByFields = selectFields.stream()
.filter(field -> dimensions.contains(field))
.filter(field -> {
@@ -71,7 +71,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true;
})
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(logicSql, groupByFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
addAggregate(semanticParseInfo);
}

View File

@@ -44,13 +44,13 @@ public class HavingCorrector extends BaseSemanticCorrector {
}
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
return;
}
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;

View File

@@ -12,15 +12,15 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
if (!CollectionUtils.isEmpty(aggregateFields)
&& !CollectionUtils.isEmpty(selectFields)
&& aggregateFields.size() == selectFields.size()) {
return;
}
addFieldsToSelect(semanticParseInfo, logicSql);
addFieldsToSelect(semanticParseInfo, correctS2SQL);
}
}

View File

@@ -46,38 +46,38 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to logicSql :{}", queryFilter);
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
Expression expression = null;
try {
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
logicSql = SqlParserAddHelper.addWhere(logicSql, expression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
logicSql = SqlParserReplaceHelper.replaceFunction(logicSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(logicSql);
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) {
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
if (StringUtils.isNotBlank(currentDate)) {
logicSql = SqlParserAddHelper.addParenthesisToWhere(logicSql);
logicSql = SqlParserAddHelper.addWhere(logicSql, TimeDimensionEnum.DAY.getChName(), currentDate);
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private String getQueryFilter(QueryFilters queryFilters) {
@@ -106,9 +106,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
return;
}

View File

@@ -41,26 +41,24 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
public String explain(User user) {
ExplainSqlReq explainSqlReq = null;
try {
ExplainResp explain = null;
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (StringUtils.isNotBlank(sqlInfo.getCorrectS2SQL())) {
//sql
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(sqlInfo.getCorrectS2SQL(),
parseInfo.getModelId());
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(queryS2SQLReq)
.queryReq(QueryReqBuilder.buildS2SQLReq(
sqlInfo.getCorrectS2SQL(), parseInfo.getModelId()
))
.build();
explain = semanticInterpreter.explain(explainSqlReq, user);
} else {
//struct
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.STRUCT)
.queryReq(queryStructReq)
.queryReq(QueryReqBuilder.buildStructReq(parseInfo))
.build();
explain = semanticInterpreter.explain(explainSqlReq, user);
}
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
if (Objects.nonNull(explain)) {
return explain.getSql();
}

View File

@@ -38,12 +38,16 @@ public class ParserInfoServiceImpl implements ParseInfoService {
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String logicSql = sqlInfo.getCorrectS2SQL();
if (StringUtils.isBlank(logicSql)) {
String correctS2SQL = sqlInfo.getCorrectS2SQL();
if (StringUtils.isBlank(correctS2SQL)) {
return;
}
// if S2SQL equals correctS2SQL, than not update the parseInfo.
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
return;
}
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
//set dataInfo
try {
if (!CollectionUtils.isEmpty(expressions)) {
@@ -70,10 +74,9 @@ public class ParserInfoServiceImpl implements ParseInfoService {
if (Objects.isNull(semanticSchema)) {
return;
}
//cannot use metrics in sql to override parse info
//List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
//Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
//parseInfo.setMetrics(metrics);
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
parseInfo.setNativeQuery(false);

View File

@@ -29,6 +29,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QueryRanker;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import com.tencent.supersonic.chat.service.ChatService;
@@ -148,6 +149,10 @@ public class QueryServiceImpl implements QueryService {
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
semanticQuery.initS2Sql(queryReq.getUser());
// rule
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
semanticCorrectors.stream().forEach(correction -> {
correction.correct(queryReq, semanticQuery.getParseInfo());
});