From 3ef3c4427765a0f3bc0f4f4744c01ff9c204761a Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sat, 6 Apr 2024 09:27:59 +0800 Subject: [PATCH] (improvement)(headless) If all the fields in the select/where statement are of tag type. (#886) --- .../core/chat/parser/QueryTypeParser.java | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java index 5d0f793a9..2a86b69b0 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java @@ -13,10 +13,12 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.core.pojo.ChatContext; import com.tencent.supersonic.headless.core.pojo.QueryContext; + import java.util.List; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -52,35 +54,51 @@ public class QueryTypeParser implements SemanticParser { Long dataSetId = parseInfo.getDataSetId(); SemanticSchema semanticSchema = queryContext.getSemanticSchema(); if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) { - //If all the fields in the SELECT statement are of tag type. - List whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL()) - .stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) - .collect(Collectors.toList()); - - if (CollectionUtils.isNotEmpty(whereFields)) { + List whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL()); + List whereFilterByTimeFields = filterByTimeFields(whereFields); + if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) { Set ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName) .collect(Collectors.toSet()); - if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) { + if (CollectionUtils.isNotEmpty(ids) && ids.stream() + .anyMatch(whereFilterByTimeFields::contains)) { return QueryType.ID; } + } + List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); + selectFields.addAll(whereFields); + List selectWhereFilterByTimeFields = filterByTimeFields(whereFields); + if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) { Set tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName) .collect(Collectors.toSet()); - if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) { + //If all the fields in the SELECT/WHERE statement are of tag type. + if (CollectionUtils.isNotEmpty(tags) + && tags.containsAll(selectWhereFilterByTimeFields)) { return QueryType.TAG; } } } //2. metric queryType - List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); - List metrics = semanticSchema.getMetrics(dataSetId); - if (CollectionUtils.isNotEmpty(metrics)) { - Set metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet()); - boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains); - if (containMetric) { - return QueryType.METRIC; - } + if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) { + return QueryType.METRIC; } return QueryType.ID; } + private static List filterByTimeFields(List whereFields) { + List selectAndWhereFilterByTimeFields = whereFields + .stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) + .collect(Collectors.toList()); + return selectAndWhereFilterByTimeFields; + } + + private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) { + List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); + List metrics = semanticSchema.getMetrics(dataSetId); + if (CollectionUtils.isNotEmpty(metrics)) { + Set metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet()); + return selectFields.stream().anyMatch(metricNameSet::contains); + } + return false; + } + }