mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(improvement)(headless) If all the fields in the select/where statement are of tag type. (#886)
This commit is contained in:
@@ -13,10 +13,12 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
|
|||||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -52,35 +54,51 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
Long dataSetId = parseInfo.getDataSetId();
|
Long dataSetId = parseInfo.getDataSetId();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||||
//If all the fields in the SELECT statement are of tag type.
|
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
||||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
|
||||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
if (CollectionUtils.isNotEmpty(ids) && ids.stream()
|
||||||
|
.anyMatch(whereFilterByTimeFields::contains)) {
|
||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
|
selectFields.addAll(whereFields);
|
||||||
|
List<String> selectWhereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||||
|
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
||||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
//If all the fields in the SELECT/WHERE statement are of tag type.
|
||||||
|
if (CollectionUtils.isNotEmpty(tags)
|
||||||
|
&& tags.containsAll(selectWhereFilterByTimeFields)) {
|
||||||
return QueryType.TAG;
|
return QueryType.TAG;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//2. metric queryType
|
//2. metric queryType
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
|
||||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
|
||||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
|
||||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
|
||||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
|
||||||
if (containMetric) {
|
|
||||||
return QueryType.METRIC;
|
return QueryType.METRIC;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||||
|
List<String> selectAndWhereFilterByTimeFields = whereFields
|
||||||
|
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
return selectAndWhereFilterByTimeFields;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
||||||
|
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
|
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||||
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
|
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||||
|
return selectFields.stream().anyMatch(metricNameSet::contains);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user