diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/AggCorrector.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/AggCorrector.java new file mode 100644 index 000000000..20e70ab6c --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/AggCorrector.java @@ -0,0 +1,32 @@ +package com.tencent.supersonic.headless.core.chat.corrector; + + +import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.core.pojo.QueryContext; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; + +import java.util.List; + +/** + * Verify whether the SQL aggregate function is missing. If it is missing, fill it in. + */ +@Slf4j +public class AggCorrector extends BaseSemanticCorrector { + + @Override + public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + addAggregate(queryContext, semanticParseInfo); + } + + private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + List sqlGroupByFields = SqlSelectHelper.getGroupByFields( + semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + if (CollectionUtils.isEmpty(sqlGroupByFields)) { + return; + } + addAggregateToMetric(queryContext, semanticParseInfo); + } + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GroupByCorrector.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GroupByCorrector.java index 1dc232d7f..3eff0e3cb 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GroupByCorrector.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GroupByCorrector.java @@ -84,15 +84,6 @@ public class GroupByCorrector extends BaseSemanticCorrector { }) .collect(Collectors.toSet()); semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields)); - addAggregate(queryContext, semanticParseInfo); } - private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { - List sqlGroupByFields = SqlSelectHelper.getGroupByFields( - semanticParseInfo.getSqlInfo().getCorrectS2SQL()); - if (CollectionUtils.isEmpty(sqlGroupByFields)) { - return; - } - addAggregateToMetric(queryContext, semanticParseInfo); - } } diff --git a/headless/core/src/test/java/com/tencent/supersonic/chat/core/corrector/AggCorrectorTest.java b/headless/core/src/test/java/com/tencent/supersonic/chat/core/corrector/AggCorrectorTest.java new file mode 100644 index 000000000..cc5d5d5ef --- /dev/null +++ b/headless/core/src/test/java/com/tencent/supersonic/chat/core/corrector/AggCorrectorTest.java @@ -0,0 +1,75 @@ +package com.tencent.supersonic.chat.core.corrector; + +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; +import com.tencent.supersonic.headless.api.pojo.QueryConfig; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.SqlInfo; +import com.tencent.supersonic.headless.core.chat.corrector.AggCorrector; +import com.tencent.supersonic.headless.core.pojo.QueryContext; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import static org.testng.Assert.assertEquals; + +class AggCorrectorTest { + + @Test + void testDoCorrect() { + AggCorrector corrector = new AggCorrector(); + Long dataSetId = 1L; + QueryContext queryContext = buildQueryContext(dataSetId); + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + SchemaElement dataSet = new SchemaElement(); + dataSet.setDataSet(dataSetId); + semanticParseInfo.setDataSet(dataSet); + SqlInfo sqlInfo = new SqlInfo(); + String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND" + + " datediff('day', 数据日期, '2024-06-04') <= 7" + + " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1"; + sqlInfo.setS2SQL(sql); + sqlInfo.setCorrectS2SQL(sql); + semanticParseInfo.setSqlInfo(sqlInfo); + corrector.correct(queryContext, semanticParseInfo); + assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'" + + " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户" + + " ORDER BY SUM(访问次数) DESC LIMIT 1", + semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + } + + private QueryContext buildQueryContext(Long dataSetId) { + QueryContext queryContext = new QueryContext(); + List dataSetSchemaList = new ArrayList<>(); + DataSetSchema dataSetSchema = new DataSetSchema(); + QueryConfig queryConfig = new QueryConfig(); + dataSetSchema.setQueryConfig(queryConfig); + SchemaElement schemaElement = new SchemaElement(); + schemaElement.setDataSet(dataSetId); + dataSetSchema.setDataSet(schemaElement); + Set dimensions = new HashSet<>(); + SchemaElement element1 = new SchemaElement(); + element1.setDataSet(1L); + element1.setName("部门"); + dimensions.add(element1); + + dataSetSchema.setDimensions(dimensions); + + Set metrics = new HashSet<>(); + SchemaElement metric1 = new SchemaElement(); + metric1.setDataSet(1L); + metric1.setName("访问次数"); + metrics.add(metric1); + + dataSetSchema.setMetrics(metrics); + dataSetSchemaList.add(dataSetSchema); + + SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList); + queryContext.setSemanticSchema(semanticSchema); + return queryContext; + } + +}