diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java index 6f3eb77df..db1adb885 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java @@ -96,6 +96,11 @@ public class SemanticSchema implements Serializable { return getElementsByDataSetId(dataSetId, metrics); } + public List getMetricNames() { + return getMetrics().stream() + .map(SchemaElement::getName).collect(Collectors.toList()); + } + public List getEntities() { List entities = new ArrayList<>(); dataSetSchemaList.stream().forEach(d -> entities.add(d.getEntity())); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/BaseSemanticCorrector.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/BaseSemanticCorrector.java index c371fcb01..0f7db87e1 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/BaseSemanticCorrector.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/BaseSemanticCorrector.java @@ -76,7 +76,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return result; } - protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) { + protected String addFieldsToSelect(QueryContext queryContext, + SemanticParseInfo semanticParseInfo, String correctS2SQL) { Set selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); @@ -89,7 +90,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { // If there is no aggregate function in the S2SQL statement and // there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement. - if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) { + if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL) + && !hasAggFunctionToAdd(queryContext.getSemanticSchema(), needAddFields)) { List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); List timeChNameList = TimeDimensionEnum.getChNameList(); Set timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field)) @@ -98,12 +100,17 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { } if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) { - return; + return correctS2SQL; } needAddFields.removeAll(selectFields); - String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields); + String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql); + return addFieldsToSelectSql; + } + + private boolean hasAggFunctionToAdd(SemanticSchema semanticSchema, Set needAddFields) { + return needAddFields.stream().anyMatch(field -> semanticSchema.getMetricNames().contains(field)); } protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java index 63d5e25af..05016a341 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java @@ -27,7 +27,7 @@ public class SelectCorrector extends BaseSemanticCorrector { && aggregateFields.size() == selectFields.size()) { return; } - addFieldsToSelect(semanticParseInfo, correctS2SQL); + correctS2SQL = addFieldsToSelect(queryContext, semanticParseInfo, correctS2SQL); String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql); }