mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
(improvement)(chat) add default aggregate to all metric and add group by to dimension and add metric filter in having (#150)
This commit is contained in:
@@ -5,13 +5,18 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
@@ -38,4 +43,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return result;
|
||||
}
|
||||
|
||||
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
semanticCorrectInfo.setSql(replaceFields);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
if (!SqlParserSelectHelper.hasAggregateFunction(sql)) {
|
||||
return;
|
||||
}
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
semanticCorrectInfo.setSql(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,13 +1,16 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -17,10 +20,11 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalCorrector extends BaseSemanticCorrector {
|
||||
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
|
||||
replaceAlias(semanticCorrectInfo);
|
||||
@@ -33,12 +37,26 @@ public class GlobalCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
//add aggregate to all metric
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
|
||||
if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Map<String, String> metricToAggregate = semanticSchema.getMetrics(modelId).stream()
|
||||
.map(schemaElement -> {
|
||||
if (Objects.isNull(schemaElement.getDefaultAgg())) {
|
||||
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||
}
|
||||
return schemaElement;
|
||||
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
|
||||
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
|
||||
String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate);
|
||||
semanticCorrectInfo.setSql(aggregateSql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
@@ -1,7 +1,17 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
@@ -9,7 +19,25 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
|
||||
//add aggregate to all metric
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
|
||||
.filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName()))
|
||||
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
Set<String> groupByFields = selectFields.stream().filter(field -> dimensions.contains(field))
|
||||
.collect(Collectors.toSet());
|
||||
semanticCorrectInfo.setSql(SqlParserUpdateHelper.addGroupBy(sql, groupByFields));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class HavingCorrector extends BaseSemanticCorrector {
|
||||
@@ -9,6 +16,22 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
|
||||
//add aggregate to all metric
|
||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
||||
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlParserUpdateHelper.addHaving(semanticCorrectInfo.getSql(), metrics);
|
||||
semanticCorrectInfo.setSql(havingSql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,16 +1,7 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class SelectCorrector extends BaseSemanticCorrector {
|
||||
@@ -19,28 +10,6 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
super.correct(semanticCorrectInfo);
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
|
||||
if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
|
||||
semanticCorrectInfo.setSql(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
semanticCorrectInfo.setSql(replaceFields);
|
||||
addFieldsToSelect(semanticCorrectInfo, sql);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,10 +128,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
|
||||
public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
|
||||
|
||||
String correctorSql = semanticCorrectInfo.getPreSql();
|
||||
if (StringUtils.isEmpty(correctorSql)) {
|
||||
correctorSql = semanticCorrectInfo.getSql();
|
||||
}
|
||||
String correctorSql = semanticCorrectInfo.getSql();
|
||||
parseInfo.getSqlInfo().setLogicSql(correctorSql);
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
||||
//set dataInfo
|
||||
|
||||
@@ -120,7 +120,7 @@ public class SemanticService {
|
||||
|
||||
return entityInfo;
|
||||
} catch (Exception e) {
|
||||
log.error("setMaintModel error {}", e);
|
||||
log.error("setMainModel error {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user