mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(Headless) Add AggCorrector to check and add aggregate functions (#1098)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Verify whether the SQL aggregate function is missing. If it is missing, fill it in.
|
||||
*/
|
||||
@Slf4j
|
||||
public class AggCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
addAggregate(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
addAggregateToMetric(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -84,15 +84,6 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
addAggregate(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
addAggregateToMetric(queryContext, semanticParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.AggCorrector;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import static org.testng.Assert.assertEquals;
|
||||
|
||||
class AggCorrectorTest {
|
||||
|
||||
@Test
|
||||
void testDoCorrect() {
|
||||
AggCorrector corrector = new AggCorrector();
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SchemaElement dataSet = new SchemaElement();
|
||||
dataSet.setDataSet(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
|
||||
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
|
||||
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(queryContext, semanticParseInfo);
|
||||
assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
|
||||
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
|
||||
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
}
|
||||
|
||||
private QueryContext buildQueryContext(Long dataSetId) {
|
||||
QueryContext queryContext = new QueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSet(1L);
|
||||
element1.setName("部门");
|
||||
dimensions.add(element1);
|
||||
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
SchemaElement metric1 = new SchemaElement();
|
||||
metric1.setDataSet(1L);
|
||||
metric1.setName("访问次数");
|
||||
metrics.add(metric1);
|
||||
|
||||
dataSetSchema.setMetrics(metrics);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
queryContext.setSemanticSchema(semanticSchema);
|
||||
return queryContext;
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user