(improvement)(headless) If all the fields in the select/where statement are of tag type. (#886)

This commit is contained in:
lexluo09
2024-04-06 09:27:59 +08:00
committed by GitHub
parent 2530407512
commit 3ef3c44277

View File

@@ -13,10 +13,12 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -52,35 +54,51 @@ public class QueryTypeParser implements SemanticParser {
Long dataSetId = parseInfo.getDataSetId();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
//If all the fields in the SELECT statement are of tag type.
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(whereFields)) {
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
if (CollectionUtils.isNotEmpty(ids) && ids.stream()
.anyMatch(whereFilterByTimeFields::contains)) {
return QueryType.ID;
}
}
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
selectFields.addAll(whereFields);
List<String> selectWhereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
//If all the fields in the SELECT/WHERE statement are of tag type.
if (CollectionUtils.isNotEmpty(tags)
&& tags.containsAll(selectWhereFilterByTimeFields)) {
return QueryType.TAG;
}
}
}
//2. metric queryType
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
if (containMetric) {
return QueryType.METRIC;
}
if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
return QueryType.METRIC;
}
return QueryType.ID;
}
private static List<String> filterByTimeFields(List<String> whereFields) {
List<String> selectAndWhereFilterByTimeFields = whereFields
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
return selectAndWhereFilterByTimeFields;
}
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
return selectFields.stream().anyMatch(metricNameSet::contains);
}
return false;
}
}