mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(headless) If all the fields in the select/where statement are of tag type. (#886)
This commit is contained in:
@@ -13,10 +13,12 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -52,35 +54,51 @@ public class QueryTypeParser implements SemanticParser {
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream()
|
||||
.anyMatch(whereFilterByTimeFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
}
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
selectFields.addAll(whereFields);
|
||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||
//If all the fields in the SELECT/WHERE statement are of tag type.
|
||||
if (CollectionUtils.isNotEmpty(tags)
|
||||
&& tags.containsAll(selectWhereFilterByTimeFields)) {
|
||||
return QueryType.TAG;
|
||||
}
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
if (containMetric) {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
return QueryType.ID;
|
||||
}
|
||||
|
||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||
List<String> selectAndWhereFilterByTimeFields = whereFields
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
return selectAndWhereFilterByTimeFields;
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
return selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user