mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[improvement](chat) QueryTypeParser tag optimize (#433)
This commit is contained in:
@@ -115,6 +115,13 @@ public class SemanticSchema implements Serializable {
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics() {
|
||||
List<SchemaElement> metrics = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||
@@ -169,6 +176,6 @@ public class SemanticSchema implements Serializable {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,16 +54,16 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.OTHER;
|
||||
}
|
||||
//1. entity queryType
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
|
||||
if (CollectionUtils.isNotEmpty(selectFields)) {
|
||||
Set<String> tags = semanticSchema.getTags().stream().map(SchemaElement::getName)
|
||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (tags.containsAll(selectFields)) {
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
|
||||
return QueryType.TAG;
|
||||
}
|
||||
}
|
||||
@@ -71,7 +71,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModel().getModelIds());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
|
||||
@@ -14,14 +14,8 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
@@ -30,6 +24,10 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* update parse info from correct sql
|
||||
@@ -47,6 +45,7 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
candidateParses.forEach(this::updateParseInfo);
|
||||
}
|
||||
|
||||
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
@@ -57,7 +56,6 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
|
||||
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
|
||||
//set dateInfo
|
||||
try {
|
||||
@@ -70,43 +68,35 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
|
||||
} catch (Exception e) {
|
||||
log.error("set dateInfo error :", e);
|
||||
}
|
||||
|
||||
//set filter
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
try {
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModel().getModelIds());
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelIds);
|
||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
||||
parseInfo.getDimensionFilters().addAll(result);
|
||||
} catch (Exception e) {
|
||||
log.error("set dimensionFilter error :", e);
|
||||
}
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
Set<SchemaElement> metrics = getElements(parseInfo.getModel().getModelIds(),
|
||||
allFields, semanticSchema.getMetrics());
|
||||
Set<SchemaElement> metrics = getElements(modelIds, allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
|
||||
parseInfo.setQueryType(QueryType.METRIC);
|
||||
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModel().getModelIds(), groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else {
|
||||
parseInfo.setQueryType(QueryType.TAG);
|
||||
parseInfo.setDimensions(getElements(modelIds, groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModel().getModelIds(), selectDimensions, semanticSchema.getDimensions()));
|
||||
parseInfo.setDimensions(getElements(modelIds, selectDimensions, semanticSchema.getDimensions()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private Set<SchemaElement> getElements(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
|
||||
private Set<SchemaElement> getElements
|
||||
(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel())
|
||||
&& allFields.contains(schemaElement.getName())
|
||||
@@ -122,9 +112,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FieldExpression> fieldExpressions) {
|
||||
List<FieldExpression> fieldExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FieldExpression expression : fieldExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
@@ -181,8 +170,9 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
|
||||
}
|
||||
|
||||
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(
|
||||
expression.getFieldValue()));
|
||||
}
|
||||
|
||||
private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
|
||||
@@ -210,7 +200,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
|
||||
}
|
||||
return result.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
|
||||
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(),
|
||||
(value1, value2) -> value2));
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user