(improvement)[chat] Skip the corrector for complex SQL, and do not add the HAVING field to the SELECT clause (#1754)

This commit is contained in:
lexluo09
2024-10-09 14:38:12 +08:00
committed by GitHub
parent 3ea3c93dc6
commit fc040970b2
9 changed files with 150 additions and 61 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
@@ -8,12 +9,18 @@ import org.springframework.util.CollectionUtils;
import java.util.List;
/** Verify whether the SQL aggregate function is missing. If it is missing, fill it in. */
/**
* Verify whether the SQL aggregate function is missing. If it is missing, fill it in.
*/
@Slf4j
public class AggCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
return;
}
addAggregate(chatQueryContext, semanticParseInfo);
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -24,6 +25,10 @@ public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
return;
}
Boolean needAddGroupBy = needAddGroupBy(chatQueryContext, semanticParseInfo);
if (!needAddGroupBy) {
return;

View File

@@ -3,14 +3,11 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.List;
@@ -23,18 +20,8 @@ public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
// add aggregate to all metric
addHaving(chatQueryContext, semanticParseInfo);
// decide whether add having expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo =
environment.getProperty("s2.corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo)
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
addHavingToSelect(semanticParseInfo);
}
}
private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
@@ -30,6 +31,9 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
return;
}
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add