(improvement)(chat) add a default aggregate to all metrics, add GROUP BY for selected dimensions, and move metric filters into HAVING (#150)

This commit is contained in:
lexluo09
2023-09-27 00:05:45 +08:00
committed by GitHub
parent ff5479f1a2
commit 24e8e756de
18 changed files with 327 additions and 85 deletions

View File

@@ -15,6 +15,7 @@ import lombok.NoArgsConstructor;
@AllArgsConstructor
@NoArgsConstructor
public class SchemaElement implements Serializable {
private Long model;
private Long id;
private String name;
@@ -26,6 +27,8 @@ public class SchemaElement implements Serializable {
private List<SchemaValueMap> schemaValueMaps;
private String defaultAgg;
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@@ -5,13 +5,18 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
@@ -38,4 +43,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return result;
}
/**
 * Appends to the SELECT list the fields that the query references in its WHERE or
 * ORDER BY clauses but does not select yet, and stores the rewritten SQL on
 * {@code semanticCorrectInfo}.
 *
 * <p>The SQL is left untouched when the parsed SELECT field list or the parsed WHERE
 * field list is empty. The day/week/month time dimension names are never appended.
 *
 * <p>NOTE(review): the emptiness check runs before the ORDER BY fields are merged in, so
 * a query that has ORDER BY fields but no WHERE fields is returned unchanged - confirm
 * that this is intended.
 *
 * @param semanticCorrectInfo holder that receives the rewritten SQL
 * @param sql the SQL text to inspect and rewrite
 */
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
    Set<String> alreadySelected = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
    Set<String> candidates = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
    boolean nothingToCompare = CollectionUtils.isEmpty(alreadySelected) || CollectionUtils.isEmpty(candidates);
    if (nothingToCompare) {
        return;
    }

    // Candidates = (WHERE fields + ORDER BY fields) - fields that are already selected.
    candidates.addAll(SqlParserSelectHelper.getOrderByFields(sql));
    candidates.removeAll(alreadySelected);

    // Time dimensions are excluded from the fields added to the select list.
    TimeDimensionEnum[] excludedTimeDimensions = {
            TimeDimensionEnum.DAY, TimeDimensionEnum.WEEK, TimeDimensionEnum.MONTH
    };
    for (TimeDimensionEnum timeDimension : excludedTimeDimensions) {
        candidates.remove(timeDimension.getName());
    }

    semanticCorrectInfo.setSql(SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(candidates)));
}
}

View File

@@ -0,0 +1,29 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
/**
 * Corrector that post-processes SQL containing an aggregate function: when the statement
 * also has a HAVING expression, that expression is passed to
 * {@code SqlParserUpdateHelper.addFunctionToSelect} and the returned SQL replaces the
 * current one. SQL without an aggregate function, or without a HAVING expression, is
 * left as it is.
 */
@Slf4j
public class GlobalAfterCorrector extends BaseSemanticCorrector {

    /**
     * Runs the base correction first, then applies the HAVING-to-SELECT rewrite.
     *
     * @param semanticCorrectInfo carries the SQL to inspect and receives the rewritten SQL
     */
    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {
        super.correct(semanticCorrectInfo);

        final String currentSql = semanticCorrectInfo.getSql();
        if (SqlParserSelectHelper.hasAggregateFunction(currentSql)) {
            Expression having = SqlParserSelectHelper.getHavingExpression(currentSql);
            if (having != null) {
                // presumably adds the function used in HAVING to the select list - confirm in SqlParserUpdateHelper
                semanticCorrectInfo.setSql(SqlParserUpdateHelper.addFunctionToSelect(currentSql, having));
            }
        }
    }
}

View File

@@ -1,13 +1,16 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -17,10 +20,11 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class GlobalCorrector extends BaseSemanticCorrector {
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
replaceAlias(semanticCorrectInfo);
@@ -33,12 +37,26 @@ public class GlobalCorrector extends BaseSemanticCorrector {
}
/**
 * Builds a map from each metric's biz name to its default aggregate for the model of the
 * current parse, then asks {@code SqlParserUpdateHelper.addAggregateToField} to rewrite
 * the SQL with that map and stores the result on {@code semanticCorrectInfo}.
 *
 * <p>Metrics without a default aggregate get {@code AggregateTypeEnum.SUM}. Nothing is
 * rewritten when the map ends up empty.
 *
 * <p>NOTE(review): in this view the block opened by the {@code hasGroupBy} check has no
 * visible matching close, so the braces of this method do not balance here - verify the
 * guard and the method end against the full file.
 *
 * @param semanticCorrectInfo carries the SQL and parse info; receives the rewritten SQL
 */
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
// Add an aggregate function to every metric field.
String sql = semanticCorrectInfo.getSql();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
// NOTE(review): the map() step below writes the SUM default back onto the schema element
// itself; if SchemaService hands out shared instances the default persists beyond this
// call - confirm that this is acceptable.
Map<String, String> metricToAggregate = semanticSchema.getMetrics(modelId).stream()
.map(schemaElement -> {
if (Objects.isNull(schemaElement.getDefaultAgg())) {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
// On duplicate biz names the first aggregate seen is kept (merge function returns k1).
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
}
String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate);
semanticCorrectInfo.setSql(aggregateSql);
}
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {

View File

@@ -1,7 +1,17 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {
@@ -9,7 +19,25 @@ public class GroupByCorrector extends BaseSemanticCorrector {
/**
 * Adds a GROUP BY over the selected dimension fields.
 *
 * <p>After the base correction, the dimension biz names of the current model (excluding
 * the day time dimension) are intersected with the fields in the SELECT list, and the
 * resulting set is handed to {@code SqlParserUpdateHelper.addGroupBy}. The SQL is left
 * unchanged when either the SELECT field list or the dimension set is empty.
 *
 * @param semanticCorrectInfo carries the SQL and parse info; receives the rewritten SQL
 */
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
    super.correct(semanticCorrectInfo);

    String originalSql = semanticCorrectInfo.getSql();
    Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
    SemanticSchema schema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();

    // Dimension names of the model, without the day time dimension.
    String dayDimension = TimeDimensionEnum.DAY.getName();
    Set<String> dimensionNames = schema.getDimensions(modelId).stream()
            .map(element -> element.getBizName())
            .filter(bizName -> !dayDimension.equals(bizName))
            .collect(Collectors.toSet());

    List<String> selected = SqlParserSelectHelper.getSelectFields(originalSql);
    if (!CollectionUtils.isEmpty(selected) && !CollectionUtils.isEmpty(dimensionNames)) {
        // Only selected fields that are dimensions become GROUP BY columns.
        Set<String> groupByColumns = selected.stream()
                .filter(dimensionNames::contains)
                .collect(Collectors.toSet());
        String groupedSql = SqlParserUpdateHelper.addGroupBy(originalSql, groupByColumns);
        semanticCorrectInfo.setSql(groupedSql);
    }
}
}

View File

@@ -1,7 +1,14 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {
@@ -9,6 +16,22 @@ public class HavingCorrector extends BaseSemanticCorrector {
/**
 * Rewrites the SQL through {@code SqlParserUpdateHelper.addHaving} using the metric biz
 * names of the current model.
 *
 * <p>The SQL as it stands after the base correction is first saved as the "pre" SQL. If
 * the model has no metrics the SQL itself is left unchanged.
 *
 * @param semanticCorrectInfo carries the SQL and parse info; receives the pre-SQL and the
 *     rewritten SQL
 */
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
    super.correct(semanticCorrectInfo);

    // Remember the SQL before any HAVING rewrite is applied.
    semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());

    Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
    SemanticSchema schema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
    Set<String> metricNames = schema.getMetrics(modelId).stream()
            .map(element -> element.getBizName())
            .collect(Collectors.toSet());

    if (!CollectionUtils.isEmpty(metricNames)) {
        // presumably moves filters on metric fields into a HAVING clause - confirm in SqlParserUpdateHelper
        semanticCorrectInfo.setSql(SqlParserUpdateHelper.addHaving(semanticCorrectInfo.getSql(), metricNames));
    }
}
}

View File

@@ -1,16 +1,7 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@Slf4j
public class SelectCorrector extends BaseSemanticCorrector {
@@ -19,28 +10,6 @@ public class SelectCorrector extends BaseSemanticCorrector {
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
    // Purpose: make sure the SELECT list carries what the rest of the query refers to.
    //  - SQL with an aggregate function: if a HAVING expression exists, it is passed to
    //    SqlParserUpdateHelper.addFunctionToSelect and the result becomes the new SQL;
    //    no further field handling happens for aggregate queries.
    //  - Any other SQL: the inherited addFieldsToSelect helper appends the WHERE/ORDER BY
    //    fields that are not selected yet.
    // The former inline copy of the addFieldsToSelect logic was removed: it computed the
    // same rewrite from the same SQL that the helper call right after it computed again.
    super.correct(semanticCorrectInfo);
    String sql = semanticCorrectInfo.getSql();

    if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
        Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
        if (Objects.nonNull(havingExpression)) {
            String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
            semanticCorrectInfo.setSql(replaceSql);
        }
        return;
    }

    // Single source of truth for the where/order-by -> select completion.
    addFieldsToSelect(semanticCorrectInfo, sql);
}
}

View File

@@ -128,10 +128,7 @@ public class LLMDslParser implements SemanticParser {
public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
String correctorSql = semanticCorrectInfo.getPreSql();
if (StringUtils.isEmpty(correctorSql)) {
correctorSql = semanticCorrectInfo.getSql();
}
String correctorSql = semanticCorrectInfo.getSql();
parseInfo.getSqlInfo().setLogicSql(correctorSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
//set dataInfo

View File

@@ -120,7 +120,7 @@ public class SemanticService {
return entityInfo;
} catch (Exception e) {
log.error("setMaintModel error {}", e);
log.error("setMainModel error {}", e);
}
}
}

View File

@@ -2,19 +2,19 @@ package com.tencent.supersonic.knowledge.semantic;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
@@ -22,7 +22,6 @@ import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.query.service.QueryService;
import com.tencent.supersonic.semantic.query.service.SchemaService;
import java.util.List;
@@ -33,7 +32,6 @@ import lombok.extern.slf4j.Slf4j;
public class LocalSemanticLayer extends BaseSemanticLayer {
private SchemaService schemaService;
private ModelService modelService;
private DimensionService dimensionService;
private MetricService metricService;
private QueryService queryService;

View File

@@ -53,6 +53,7 @@ public class ModelSchemaBuilder {
.type(SchemaElementType.METRIC)
.useCnt(metric.getUseCnt())
.alias(alias)
.defaultAgg(metric.getDefaultAgg())
.build();
metrics.add(metricToAdd);