From df7fea9ee387772ccd9baeb08cd62afac99ff89f Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:55:59 +0800 Subject: [PATCH] (improvement)(chat) add addAggregateToMetric in GlobalAfterCorrector and fix getAgg null (#152) --- .../chat/corrector/BaseSemanticCorrector.java | 25 ++++++++++++ .../chat/corrector/GlobalAfterCorrector.java | 1 + .../chat/corrector/GlobalBeforeCorrector.java | 25 ------------ .../model/application/CatalogImpl.java | 38 ++++++++++++------- 4 files changed, 51 insertions(+), 38 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index 270f8126e..3665a5d04 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticCorrector; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; @@ -13,6 +14,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; @@ -59,4 +61,27 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); semanticCorrectInfo.setSql(replaceFields); } + + protected void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { + //add aggregate to all metric + String sql = semanticCorrectInfo.getSql(); + Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + Map metricToAggregate = semanticSchema.getMetrics(modelId).stream() + .map(schemaElement -> { + if (Objects.isNull(schemaElement.getDefaultAgg())) { + schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name()); + } + return schemaElement; + }).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1)); + + if (CollectionUtils.isEmpty(metricToAggregate)) { + return; + } + + String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate); + semanticCorrectInfo.setSql(aggregateSql); + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java index 5d73483b6..ea5d5a01e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalAfterCorrector.java @@ -14,6 +14,7 @@ public class GlobalAfterCorrector extends BaseSemanticCorrector { public void correct(SemanticCorrectInfo semanticCorrectInfo) { super.correct(semanticCorrectInfo); + addAggregateToMetric(semanticCorrectInfo); String sql = semanticCorrectInfo.getSql(); if (!SqlParserSelectHelper.hasAggregateFunction(sql)) { return; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java index 36cd1bb71..82a1f4f02 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java @@ -1,16 +1,12 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; -import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.List; import java.util.Map; import java.util.Objects; @@ -36,28 +32,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { addAggregateToMetric(semanticCorrectInfo); } - private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { - //add aggregate to all metric - String sql = semanticCorrectInfo.getSql(); - Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel(); - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - - Map metricToAggregate = semanticSchema.getMetrics(modelId).stream() - .map(schemaElement -> { - if (Objects.isNull(schemaElement.getDefaultAgg())) { - schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name()); - } - return schemaElement; - }).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1)); - - if (CollectionUtils.isEmpty(metricToAggregate)) { - return; - } - - String aggregateSql = SqlParserUpdateHelper.addAggregateToField(sql, metricToAggregate); - semanticCorrectInfo.setSql(aggregateSql); - } private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql()); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java index e7b517ab4..2612a2698 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java @@ -23,6 +23,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -103,22 +104,33 @@ public class CatalogImpl implements Catalog { @Override public String getAgg(Long modelId, String metricBizName) { - List metricResps = getMetrics(modelId); - if (!CollectionUtils.isEmpty(metricResps)) { - Optional metric = metricResps.stream() - .filter(m -> m.getBizName().equalsIgnoreCase(metricBizName)).findFirst(); - if (metric.isPresent() && Objects.nonNull(metric.get().getTypeParams()) && !CollectionUtils.isEmpty( - metric.get().getTypeParams().getMeasures())) { - List measureRespList = datasourceService.getMeasureListOfModel(modelId); - if (!CollectionUtils.isEmpty(measureRespList)) { - String measureName = metric.get().getTypeParams().getMeasures().get(0).getBizName(); - Optional measure = measureRespList.stream() - .filter(m -> m.getBizName().equalsIgnoreCase(measureName)).findFirst(); - if (measure.isPresent()) { - return measure.get().getAgg(); + try { + List metricResps = getMetrics(modelId); + if (!CollectionUtils.isEmpty(metricResps)) { + Optional metric = metricResps.stream() + .filter(m -> m.getBizName().equalsIgnoreCase(metricBizName)).findFirst(); + if (metric.isPresent() && Objects.nonNull(metric.get().getTypeParams()) && !CollectionUtils.isEmpty( + metric.get().getTypeParams().getMeasures())) { + List measureRespList = datasourceService.getMeasureListOfModel(modelId); + if (!CollectionUtils.isEmpty(measureRespList)) { + String measureName = metric.get().getTypeParams().getMeasures().get(0).getBizName(); + Optional measure = measureRespList.stream() + .filter(Objects::nonNull) + .filter(m -> { + if (StringUtils.isNotEmpty(m.getBizName())) { + return m.getBizName().equalsIgnoreCase(measureName); + } + return false; + }) + .findFirst(); + if (measure.isPresent()) { + return measure.get().getAgg(); + } } } } + } catch (Exception e) { + log.error("getAgg:", e); } return ""; }