[improvement](chat) QueryTypeParser tag optimize (#433)

This commit is contained in:
lexluo09
2023-11-27 22:58:24 +08:00
committed by GitHub
parent 667272b103
commit 87e222eecc
3 changed files with 32 additions and 34 deletions

View File

@@ -115,6 +115,13 @@ public class SemanticSchema implements Serializable {
return tags; return tags;
} }
public List<SchemaElement> getTags(Set<Long> modelIds) {
List<SchemaElement> tags = new ArrayList<>();
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
.forEach(d -> tags.addAll(d.getTags()));
return tags;
}
public List<SchemaElement> getMetrics() { public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>(); List<SchemaElement> metrics = new ArrayList<>();
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics())); modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
@@ -169,6 +176,6 @@ public class SemanticSchema implements Serializable {
return new HashMap<>(); return new HashMap<>();
} }
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema)); -> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
} }
} }

View File

@@ -54,16 +54,16 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.OTHER; return QueryType.OTHER;
} }
//1. entity queryType //1. entity queryType
Set<Long> modelIds = parseInfo.getModel().getModelIds();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) { if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
//If all the fields in the SELECT statement are of tag type. //If all the fields in the SELECT statement are of tag type.
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL()); List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema(); SemanticSchema semanticSchema = semanticService.getSemanticSchema();
if (CollectionUtils.isNotEmpty(selectFields)) { if (CollectionUtils.isNotEmpty(selectFields)) {
Set<String> tags = semanticSchema.getTags().stream().map(SchemaElement::getName) Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
if (tags.containsAll(selectFields)) { if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
return QueryType.TAG; return QueryType.TAG;
} }
} }
@@ -71,7 +71,7 @@ public class QueryTypeParser implements SemanticParser {
//2. metric queryType //2. metric queryType
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL()); List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModel().getModelIds()); List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
if (CollectionUtils.isNotEmpty(metrics)) { if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet()); Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains); boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);

View File

@@ -14,14 +14,8 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashSet; import java.util.HashSet;
@@ -30,6 +24,10 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
/** /**
* update parse info from correct sql * update parse info from correct sql
@@ -47,6 +45,7 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
.map(SemanticQuery::getParseInfo).collect(Collectors.toList()); .map(SemanticQuery::getParseInfo).collect(Collectors.toList());
candidateParses.forEach(this::updateParseInfo); candidateParses.forEach(this::updateParseInfo);
} }
public void updateParseInfo(SemanticParseInfo parseInfo) { public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL(); String correctS2SQL = sqlInfo.getCorrectS2SQL();
@@ -57,7 +56,6 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
if (correctS2SQL.equals(sqlInfo.getS2SQL())) { if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
return; return;
} }
List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL); List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
//set dataInfo //set dataInfo
try { try {
@@ -70,43 +68,35 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
} catch (Exception e) { } catch (Exception e) {
log.error("set dateInfo error :", e); log.error("set dateInfo error :", e);
} }
//set filter //set filter
Set<Long> modelIds = parseInfo.getModel().getModelIds();
try { try {
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModel().getModelIds()); Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelIds);
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions); List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result); parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) { } catch (Exception e) {
log.error("set dimensionFilter error :", e); log.error("set dimensionFilter error :", e);
} }
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
if (Objects.isNull(semanticSchema)) { if (Objects.isNull(semanticSchema)) {
return; return;
} }
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(parseInfo.getModel().getModelIds(), Set<SchemaElement> metrics = getElements(modelIds, allFields, semanticSchema.getMetrics());
allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics); parseInfo.setMetrics(metrics);
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
parseInfo.setQueryType(QueryType.METRIC);
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL()); List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields); List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions( parseInfo.setDimensions(getElements(modelIds, groupByDimensions, semanticSchema.getDimensions()));
getElements(parseInfo.getModel().getModelIds(), groupByDimensions, semanticSchema.getDimensions())); } else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
} else {
parseInfo.setQueryType(QueryType.TAG);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL()); List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields); List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions( parseInfo.setDimensions(getElements(modelIds, selectDimensions, semanticSchema.getDimensions()));
getElements(parseInfo.getModel().getModelIds(), selectDimensions, semanticSchema.getDimensions()));
} }
} }
private Set<SchemaElement> getElements
private Set<SchemaElement> getElements(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) { (Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
return elements.stream() return elements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()) .filter(schemaElement -> modelIds.contains(schemaElement.getModel())
&& allFields.contains(schemaElement.getName()) && allFields.contains(schemaElement.getName())
@@ -122,9 +112,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement, private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FieldExpression> fieldExpressions) { List<FieldExpression> fieldExpressions) {
List<QueryFilter> result = Lists.newArrayList(); List<QueryFilter> result = Lists.newArrayList();
for (FieldExpression expression : fieldExpressions) { for (FieldExpression expression : fieldExpressions) {
QueryFilter dimensionFilter = new QueryFilter(); QueryFilter dimensionFilter = new QueryFilter();
@@ -181,8 +170,9 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
} }
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator, private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) { FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(
expression.getFieldValue()));
} }
private boolean hasSecondDate(List<FieldExpression> dateExpressions) { private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
@@ -210,7 +200,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
} }
return result.stream(); return result.stream();
}) })
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(),
(value1, value2) -> value2));
} }
} }