(improvement)(Headless) Add AggCorrector to check and add aggregate functions (#1098)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-06-06 10:24:30 +08:00
committed by GitHub
parent 2ed7762210
commit 357e2c03eb
3 changed files with 107 additions and 9 deletions

View File

@@ -0,0 +1,32 @@
package com.tencent.supersonic.headless.core.chat.corrector;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
/**
* Verify whether the SQL aggregate function is missing. If it is missing, fill it in.
*/
@Slf4j
public class AggCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
addAggregate(queryContext, semanticParseInfo);
}
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}
addAggregateToMetric(queryContext, semanticParseInfo);
}
}

View File

@@ -84,15 +84,6 @@ public class GroupByCorrector extends BaseSemanticCorrector {
})
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
addAggregate(queryContext, semanticParseInfo);
}
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}
addAggregateToMetric(queryContext, semanticParseInfo);
}
}

View File

@@ -0,0 +1,75 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.corrector.AggCorrector;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.testng.Assert.assertEquals;
class AggCorrectorTest {
@Test
void testDoCorrect() {
AggCorrector corrector = new AggCorrector();
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SchemaElement dataSet = new SchemaElement();
dataSet.setDataSet(dataSetId);
semanticParseInfo.setDataSet(dataSet);
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(queryContext, semanticParseInfo);
assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
}
private QueryContext buildQueryContext(Long dataSetId) {
QueryContext queryContext = new QueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSet(1L);
element1.setName("部门");
dimensions.add(element1);
dataSetSchema.setDimensions(dimensions);
Set<SchemaElement> metrics = new HashSet<>();
SchemaElement metric1 = new SchemaElement();
metric1.setDataSet(1L);
metric1.setName("访问次数");
metrics.add(metric1);
dataSetSchema.setMetrics(metrics);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}
}