(improvement)(Headless) If it is in DETAIL mode and select *, add default metrics and dimensions. (#1186)

This commit is contained in:
lexluo09
2024-06-22 01:28:52 +08:00
committed by GitHub
parent e293be3ebf
commit cfde267a06
8 changed files with 271 additions and 50 deletions

View File

@@ -1,20 +1,12 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -22,6 +14,10 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
/**
* basic semantic correction functionality, offering common methods and an
@@ -75,27 +71,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return result;
}
protected String addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
//decide whether add order by expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
return correctS2SQL;
}
needAddFields.removeAll(selectFields);
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
return addFieldsToSelectSql;
}
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();

View File

@@ -1,11 +1,24 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.List;
@@ -27,8 +40,71 @@ public class SelectCorrector extends BaseSemanticCorrector {
&& aggregateFields.size() == selectFields.size()) {
return;
}
correctS2SQL = addFieldsToSelect(semanticParseInfo, correctS2SQL);
correctS2SQL = addFieldsToSelect(queryContext, semanticParseInfo, correctS2SQL);
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
}
protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
String correctS2SQL) {
correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL);
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
//decide whether add order by expression field to select
String correctorAdditionalInfo = getAdditionalInfo();
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
return correctS2SQL;
}
needAddFields.removeAll(selectFields);
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
return addFieldsToSelectSql;
}
private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
String correctS2SQL) {
//If it is in DETAIL mode and select *, add default metrics and dimensions.
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
return correctS2SQL;
}
Long dataSetId = semanticParseInfo.getDataSetId();
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
Set<String> needAddDefaultFields = new HashSet<>();
if (Objects.nonNull(dataSetSchema)) {
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
Set<String> metrics = dataSetSchema.getTagDefaultMetrics()
.stream().map(schemaElement -> schemaElement.getName())
.collect(Collectors.toSet());
needAddDefaultFields.addAll(metrics);
}
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) {
Set<String> dimensions = dataSetSchema.getTagDefaultDimensions()
.stream().map(schemaElement -> schemaElement.getName())
.collect(Collectors.toSet());
needAddDefaultFields.addAll(dimensions);
}
}
// remove * in sql and add default fields.
if (!CollectionUtils.isEmpty(needAddDefaultFields)) {
correctS2SQL = SqlRemoveHelper.removeAsteriskAndAddFields(correctS2SQL, needAddDefaultFields);
}
return correctS2SQL;
}
private String getAdditionalInfo() {
String correctorAdditionalInfo = null;
try {
Environment environment = ContextUtils.getBean(Environment.class);
correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
} catch (Exception e) {
log.error("getAdditionalInfo error:{}", e);
}
return correctorAdditionalInfo;
}
}

View File

@@ -0,0 +1,102 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.chat.QueryContext;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class SelectCorrectorTest {
Long dataSetId = 2L;
@Test
void testDoCorrect() {
BaseSemanticCorrector corrector = new SelectCorrector();
QueryContext queryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SchemaElement dataSet = new SchemaElement();
dataSet.setDataSet(dataSetId);
semanticParseInfo.setDataSet(dataSet);
semanticParseInfo.setQueryType(QueryType.DETAIL);
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(queryContext, semanticParseInfo);
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
}
private QueryContext buildQueryContext(Long dataSetId) {
QueryContext queryContext = new QueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig();
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
List<Long> dimensionIds = new ArrayList<>();
dimensionIds.add(1L);
dimensionIds.add(2L);
dimensionIds.add(3L);
defaultDisplayInfo.setDimensionIds(dimensionIds);
List<Long> metricIds = new ArrayList<>();
metricIds.add(4L);
defaultDisplayInfo.setMetricIds(metricIds);
tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig);
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSet(dataSetId);
element1.setId(1L);
element1.setName("艺人名");
dimensions.add(element1);
SchemaElement element2 = new SchemaElement();
element2.setDataSet(dataSetId);
element2.setId(2L);
element2.setName("性别");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setDataSet(dataSetId);
element3.setId(3L);
element3.setName("国籍");
dimensions.add(element3);
dataSetSchema.setDimensions(dimensions);
Set<SchemaElement> metrics = new HashSet<>();
SchemaElement metric1 = new SchemaElement();
metric1.setDataSet(dataSetId);
metric1.setId(4L);
metric1.setName("粉丝数");
metrics.add(metric1);
dataSetSchema.setMetrics(metrics);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}
}