[improvement](chat) QueryTypeParser tag optimize (#433)

This commit is contained in:
lexluo09
2023-11-27 22:58:24 +08:00
committed by GitHub
parent 667272b103
commit 87e222eecc
3 changed files with 32 additions and 34 deletions

View File

@@ -115,6 +115,13 @@ public class SemanticSchema implements Serializable {
return tags;
}
/**
 * Collects the tags of every model schema whose model id is contained in {@code modelIds}.
 *
 * <p>The id is read through {@code getModel().getModel()}, the same accessor chain this class
 * uses as the key when mapping model schemas. Testing the object returned by a single
 * {@code getModel()} call against a {@code Set<Long>} can never match, which would make this
 * method always return an empty list.
 *
 * @param modelIds ids of the models whose tags should be returned
 * @return a new mutable list with the tags of all matching model schemas; empty when none match
 */
public List<SchemaElement> getTags(Set<Long> modelIds) {
    return modelSchemaList.stream()
            // Compare the model's id (not the model element itself) against the requested ids.
            .filter(modelSchema -> modelIds.contains(modelSchema.getModel().getModel()))
            .flatMap(modelSchema -> modelSchema.getTags().stream())
            .collect(Collectors.toCollection(ArrayList::new));
}
public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
@@ -169,6 +176,6 @@ public class SemanticSchema implements Serializable {
return new HashMap<>();
}
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
}
}

View File

@@ -54,16 +54,16 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.OTHER;
}
//1. entity queryType
Set<Long> modelIds = parseInfo.getModel().getModelIds();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
//If all the fields in the SELECT statement are of tag type.
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
if (CollectionUtils.isNotEmpty(selectFields)) {
Set<String> tags = semanticSchema.getTags().stream().map(SchemaElement::getName)
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (tags.containsAll(selectFields)) {
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
return QueryType.TAG;
}
}
@@ -71,7 +71,7 @@ public class QueryTypeParser implements SemanticParser {
//2. metric queryType
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModel().getModelIds());
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);

View File

@@ -14,14 +14,8 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@@ -30,6 +24,10 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
/**
* update parse info from correct sql
@@ -47,6 +45,7 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
candidateParses.forEach(this::updateParseInfo);
}
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
@@ -57,7 +56,6 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
return;
}
List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
//set dateInfo
try {
@@ -70,43 +68,35 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
} catch (Exception e) {
log.error("set dateInfo error :", e);
}
//set filter
Set<Long> modelIds = parseInfo.getModel().getModelIds();
try {
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModel().getModelIds());
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelIds);
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(parseInfo.getModel().getModelIds(),
allFields, semanticSchema.getMetrics());
Set<SchemaElement> metrics = getElements(modelIds, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
parseInfo.setQueryType(QueryType.METRIC);
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(
getElements(parseInfo.getModel().getModelIds(), groupByDimensions, semanticSchema.getDimensions()));
} else {
parseInfo.setQueryType(QueryType.TAG);
parseInfo.setDimensions(getElements(modelIds, groupByDimensions, semanticSchema.getDimensions()));
} else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(
getElements(parseInfo.getModel().getModelIds(), selectDimensions, semanticSchema.getDimensions()));
parseInfo.setDimensions(getElements(modelIds, selectDimensions, semanticSchema.getDimensions()));
}
}
private Set<SchemaElement> getElements(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
private Set<SchemaElement> getElements
(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel())
&& allFields.contains(schemaElement.getName())
@@ -122,9 +112,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
.collect(Collectors.toList());
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FieldExpression> fieldExpressions) {
List<FieldExpression> fieldExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FieldExpression expression : fieldExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
@@ -181,8 +170,9 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
}
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(
expression.getFieldValue()));
}
private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
@@ -210,7 +200,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
}
return result.stream();
})
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(),
(value1, value2) -> value2));
}
}