(improvement)(chat) add addAggregateToMetric in GlobalAfterCorrector and fix getAgg null (#152)

This commit is contained in:
lexluo09
2023-09-27 12:55:59 +08:00
committed by GitHub
parent 24e8e756de
commit df7fea9ee3
4 changed files with 51 additions and 38 deletions

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
@@ -13,6 +14,7 @@ import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -59,4 +61,27 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
semanticCorrectInfo.setSql(replaceFields); semanticCorrectInfo.setSql(replaceFields);
} }
protected void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
//add aggregate to all metric
String sql = semanticCorrectInfo.getSql();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<String, String> metricToAggregate = semanticSchema.getMetrics(modelId).stream()
.map(schemaElement -> {
if (Objects.isNull(schemaElement.getDefaultAgg())) {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
}
String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate);
semanticCorrectInfo.setSql(aggregateSql);
}
} }

View File

@@ -14,6 +14,7 @@ public class GlobalAfterCorrector extends BaseSemanticCorrector {
public void correct(SemanticCorrectInfo semanticCorrectInfo) { public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo); super.correct(semanticCorrectInfo);
addAggregateToMetric(semanticCorrectInfo);
String sql = semanticCorrectInfo.getSql(); String sql = semanticCorrectInfo.getSql();
if (!SqlParserSelectHelper.hasAggregateFunction(sql)) { if (!SqlParserSelectHelper.hasAggregateFunction(sql)) {
return; return;

View File

@@ -1,16 +1,12 @@
package com.tencent.supersonic.chat.corrector; package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -36,28 +32,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
addAggregateToMetric(semanticCorrectInfo); addAggregateToMetric(semanticCorrectInfo);
} }
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
//add aggregate to all metric
String sql = semanticCorrectInfo.getSql();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<String, String> metricToAggregate = semanticSchema.getMetrics(modelId).stream()
.map(schemaElement -> {
if (Objects.isNull(schemaElement.getDefaultAgg())) {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
}
String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate);
semanticCorrectInfo.setSql(aggregateSql);
}
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql()); String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());

View File

@@ -23,6 +23,7 @@ import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -103,6 +104,7 @@ public class CatalogImpl implements Catalog {
@Override @Override
public String getAgg(Long modelId, String metricBizName) { public String getAgg(Long modelId, String metricBizName) {
try {
List<MetricResp> metricResps = getMetrics(modelId); List<MetricResp> metricResps = getMetrics(modelId);
if (!CollectionUtils.isEmpty(metricResps)) { if (!CollectionUtils.isEmpty(metricResps)) {
Optional<MetricResp> metric = metricResps.stream() Optional<MetricResp> metric = metricResps.stream()
@@ -113,13 +115,23 @@ public class CatalogImpl implements Catalog {
if (!CollectionUtils.isEmpty(measureRespList)) { if (!CollectionUtils.isEmpty(measureRespList)) {
String measureName = metric.get().getTypeParams().getMeasures().get(0).getBizName(); String measureName = metric.get().getTypeParams().getMeasures().get(0).getBizName();
Optional<MeasureResp> measure = measureRespList.stream() Optional<MeasureResp> measure = measureRespList.stream()
.filter(m -> m.getBizName().equalsIgnoreCase(measureName)).findFirst(); .filter(Objects::nonNull)
.filter(m -> {
if (StringUtils.isNotEmpty(m.getBizName())) {
return m.getBizName().equalsIgnoreCase(measureName);
}
return false;
})
.findFirst();
if (measure.isPresent()) { if (measure.isPresent()) {
return measure.get().getAgg(); return measure.get().getAgg();
} }
} }
} }
} }
} catch (Exception e) {
log.error("getAgg:", e);
}
return ""; return "";
} }
} }