mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(chat) add addAggregateToMetric in GlobalAfterCorrector and fix getAgg null (#152)
This commit is contained in:
@@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||||
@@ -13,6 +14,7 @@ import java.util.ArrayList;
|
|||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -59,4 +61,27 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||||
semanticCorrectInfo.setSql(replaceFields);
|
semanticCorrectInfo.setSql(replaceFields);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
//add aggregate to all metric
|
||||||
|
String sql = semanticCorrectInfo.getSql();
|
||||||
|
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||||
|
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
|
||||||
|
Map<String, String> metricToAggregate = semanticSchema.getMetrics(modelId).stream()
|
||||||
|
.map(schemaElement -> {
|
||||||
|
if (Objects.isNull(schemaElement.getDefaultAgg())) {
|
||||||
|
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||||
|
}
|
||||||
|
return schemaElement;
|
||||||
|
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate);
|
||||||
|
semanticCorrectInfo.setSql(aggregateSql);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
|||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
|
||||||
super.correct(semanticCorrectInfo);
|
super.correct(semanticCorrectInfo);
|
||||||
|
addAggregateToMetric(semanticCorrectInfo);
|
||||||
String sql = semanticCorrectInfo.getSql();
|
String sql = semanticCorrectInfo.getSql();
|
||||||
if (!SqlParserSelectHelper.hasAggregateFunction(sql)) {
|
if (!SqlParserSelectHelper.hasAggregateFunction(sql)) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -36,28 +32,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
|||||||
addAggregateToMetric(semanticCorrectInfo);
|
addAggregateToMetric(semanticCorrectInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
//add aggregate to all metric
|
|
||||||
String sql = semanticCorrectInfo.getSql();
|
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
|
|
||||||
Map<String, String> metricToAggregate = semanticSchema.getMetrics(modelId).stream()
|
|
||||||
.map(schemaElement -> {
|
|
||||||
if (Objects.isNull(schemaElement.getDefaultAgg())) {
|
|
||||||
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
|
||||||
}
|
|
||||||
return schemaElement;
|
|
||||||
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate);
|
|
||||||
semanticCorrectInfo.setSql(aggregateSql);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
|
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import java.util.Objects;
|
|||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -103,22 +104,33 @@ public class CatalogImpl implements Catalog {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getAgg(Long modelId, String metricBizName) {
|
public String getAgg(Long modelId, String metricBizName) {
|
||||||
List<MetricResp> metricResps = getMetrics(modelId);
|
try {
|
||||||
if (!CollectionUtils.isEmpty(metricResps)) {
|
List<MetricResp> metricResps = getMetrics(modelId);
|
||||||
Optional<MetricResp> metric = metricResps.stream()
|
if (!CollectionUtils.isEmpty(metricResps)) {
|
||||||
.filter(m -> m.getBizName().equalsIgnoreCase(metricBizName)).findFirst();
|
Optional<MetricResp> metric = metricResps.stream()
|
||||||
if (metric.isPresent() && Objects.nonNull(metric.get().getTypeParams()) && !CollectionUtils.isEmpty(
|
.filter(m -> m.getBizName().equalsIgnoreCase(metricBizName)).findFirst();
|
||||||
metric.get().getTypeParams().getMeasures())) {
|
if (metric.isPresent() && Objects.nonNull(metric.get().getTypeParams()) && !CollectionUtils.isEmpty(
|
||||||
List<MeasureResp> measureRespList = datasourceService.getMeasureListOfModel(modelId);
|
metric.get().getTypeParams().getMeasures())) {
|
||||||
if (!CollectionUtils.isEmpty(measureRespList)) {
|
List<MeasureResp> measureRespList = datasourceService.getMeasureListOfModel(modelId);
|
||||||
String measureName = metric.get().getTypeParams().getMeasures().get(0).getBizName();
|
if (!CollectionUtils.isEmpty(measureRespList)) {
|
||||||
Optional<MeasureResp> measure = measureRespList.stream()
|
String measureName = metric.get().getTypeParams().getMeasures().get(0).getBizName();
|
||||||
.filter(m -> m.getBizName().equalsIgnoreCase(measureName)).findFirst();
|
Optional<MeasureResp> measure = measureRespList.stream()
|
||||||
if (measure.isPresent()) {
|
.filter(Objects::nonNull)
|
||||||
return measure.get().getAgg();
|
.filter(m -> {
|
||||||
|
if (StringUtils.isNotEmpty(m.getBizName())) {
|
||||||
|
return m.getBizName().equalsIgnoreCase(measureName);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
})
|
||||||
|
.findFirst();
|
||||||
|
if (measure.isPresent()) {
|
||||||
|
return measure.get().getAgg();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("getAgg:", e);
|
||||||
}
|
}
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user