mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
(improvement)[chat] Skip the corrector for complex SQL, and do not add the HAVING field to the SELECT clause (#1754)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -8,12 +9,18 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/** Verify whether the SQL aggregate function is missing. If it is missing, fill it in. */
|
||||
/**
|
||||
* Verify whether the SQL aggregate function is missing. If it is missing, fill it in.
|
||||
*/
|
||||
@Slf4j
|
||||
public class AggCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
addAggregate(chatQueryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -24,6 +25,10 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
Boolean needAddGroupBy = needAddGroupBy(chatQueryContext, semanticParseInfo);
|
||||
if (!needAddGroupBy) {
|
||||
return;
|
||||
|
||||
@@ -3,14 +3,11 @@ package com.tencent.supersonic.headless.chat.corrector;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
@@ -23,18 +20,8 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
// add aggregate to all metric
|
||||
addHaving(chatQueryContext, semanticParseInfo);
|
||||
|
||||
// decide whether add having expression field to select
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorAdditionalInfo =
|
||||
environment.getProperty("s2.corrector.additional.information");
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo)
|
||||
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
@@ -30,6 +31,9 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (SqlValidHelper.isComplexSQL(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add
|
||||
|
||||
@@ -274,9 +274,9 @@ public class QueryReqConverter {
|
||||
String sql = viewQueryParam.getSql();
|
||||
for (MetricTable metricTable : viewQueryParam.getTables()) {
|
||||
Set<String> measures = new HashSet<>();
|
||||
Map<String, String> replaces = new HashMap<>();
|
||||
generateDerivedMetric(semanticSchemaResp, aggOption, metricTable.getMetrics(),
|
||||
metricTable.getDimensions(), measures, replaces);
|
||||
Map<String, String> replaces = generateDerivedMetric(semanticSchemaResp, aggOption,
|
||||
metricTable.getMetrics(), metricTable.getDimensions(), measures);
|
||||
|
||||
if (!CollectionUtils.isEmpty(replaces)) {
|
||||
// metricTable sql use measures replace metric
|
||||
sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces);
|
||||
@@ -295,49 +295,59 @@ public class QueryReqConverter {
|
||||
viewQueryParam.setSql(sql);
|
||||
}
|
||||
|
||||
private void generateDerivedMetric(SemanticSchemaResp semanticSchemaResp, AggOption aggOption,
|
||||
List<String> metrics, List<String> dimensions, Set<String> measures,
|
||||
Map<String, String> replaces) {
|
||||
private Map<String, String> generateDerivedMetric(SemanticSchemaResp semanticSchemaResp,
|
||||
AggOption aggOption, List<String> metrics, List<String> dimensions,
|
||||
Set<String> measures) {
|
||||
Map<String, String> result = new HashMap<>();
|
||||
List<MetricSchemaResp> metricResps = semanticSchemaResp.getMetrics();
|
||||
List<DimSchemaResp> dimensionResps = semanticSchemaResp.getDimensions();
|
||||
// check metrics has derived
|
||||
if (!metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType
|
||||
.isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()))) {
|
||||
return;
|
||||
|
||||
// Check if any metric is derived
|
||||
boolean hasDerivedMetrics =
|
||||
metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType
|
||||
.isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()));
|
||||
if (!hasDerivedMetrics) {
|
||||
return result;
|
||||
}
|
||||
|
||||
log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics);
|
||||
|
||||
Set<String> allFields = new HashSet<>();
|
||||
Map<String, Measure> allMeasures = new HashMap<>();
|
||||
semanticSchemaResp.getModelResps().forEach(modelResp -> {
|
||||
allFields.addAll(modelResp.getFieldList());
|
||||
if (Objects.nonNull(modelResp.getModelDetail().getMeasures())) {
|
||||
modelResp.getModelDetail().getMeasures().stream()
|
||||
.forEach(mm -> allMeasures.put(mm.getBizName(), mm));
|
||||
if (modelResp.getModelDetail().getMeasures() != null) {
|
||||
modelResp.getModelDetail().getMeasures()
|
||||
.forEach(measure -> allMeasures.put(measure.getBizName(), measure));
|
||||
}
|
||||
});
|
||||
Set<String> deriveDimension = new HashSet<>();
|
||||
Set<String> deriveMetric = new HashSet<>();
|
||||
Map<String, String> visitedMetric = new HashMap<>();
|
||||
if (!CollectionUtils.isEmpty(metricResps)) {
|
||||
for (MetricResp metricResp : metricResps) {
|
||||
if (metrics.contains(metricResp.getBizName())) {
|
||||
if (MetricType.isDerived(metricResp.getMetricDefineType(),
|
||||
metricResp.getMetricDefineByMeasureParams())) {
|
||||
String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields,
|
||||
allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp),
|
||||
metricResp.getMetricDefineType(), aggOption, visitedMetric,
|
||||
deriveMetric, deriveDimension);
|
||||
replaces.put(metricResp.getBizName(), expr);
|
||||
log.debug("derived metric {}->{}", metricResp.getBizName(), expr);
|
||||
} else {
|
||||
measures.add(metricResp.getBizName());
|
||||
}
|
||||
|
||||
Set<String> derivedDimensions = new HashSet<>();
|
||||
Set<String> derivedMetrics = new HashSet<>();
|
||||
Map<String, String> visitedMetrics = new HashMap<>();
|
||||
|
||||
for (MetricResp metricResp : metricResps) {
|
||||
if (metrics.contains(metricResp.getBizName())) {
|
||||
boolean isDerived = MetricType.isDerived(metricResp.getMetricDefineType(),
|
||||
metricResp.getMetricDefineByMeasureParams());
|
||||
if (isDerived) {
|
||||
String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields,
|
||||
allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp),
|
||||
metricResp.getMetricDefineType(), aggOption, visitedMetrics,
|
||||
derivedMetrics, derivedDimensions);
|
||||
result.put(metricResp.getBizName(), expr);
|
||||
log.debug("derived metric {}->{}", metricResp.getBizName(), expr);
|
||||
} else {
|
||||
measures.add(metricResp.getBizName());
|
||||
}
|
||||
}
|
||||
}
|
||||
measures.addAll(deriveMetric);
|
||||
deriveDimension.stream().filter(d -> !dimensions.contains(d))
|
||||
.forEach(d -> dimensions.add(d));
|
||||
|
||||
measures.addAll(derivedMetrics);
|
||||
derivedDimensions.stream().filter(dimension -> !dimensions.contains(dimension))
|
||||
.forEach(dimensions::add);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private String getDefaultModel(SemanticSchemaResp semanticSchemaResp, List<String> dimensions) {
|
||||
|
||||
Reference in New Issue
Block a user