diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java index fbf4c5a7c..b7ec83076 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java @@ -115,6 +115,13 @@ public class SemanticSchema implements Serializable { return tags; } + public List getTags(Set modelIds) { + List tags = new ArrayList<>(); + modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel())) + .forEach(d -> tags.addAll(d.getTags())); + return tags; + } + public List getMetrics() { List metrics = new ArrayList<>(); modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics())); @@ -169,6 +176,6 @@ public class SemanticSchema implements Serializable { return new HashMap<>(); } return modelSchemaList.stream().collect(Collectors.toMap(modelSchema - -> modelSchema.getModel().getModel(), modelSchema -> modelSchema)); + -> modelSchema.getModel().getModel(), modelSchema -> modelSchema)); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java index ee747a3f6..852dd0b67 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java @@ -54,16 +54,16 @@ public class QueryTypeParser implements SemanticParser { return QueryType.OTHER; } //1. entity queryType + Set modelIds = parseInfo.getModel().getModelIds(); if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) { //If all the fields in the SELECT statement are of tag type. List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL()); SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticSchema semanticSchema = semanticService.getSemanticSchema(); - if (CollectionUtils.isNotEmpty(selectFields)) { - Set tags = semanticSchema.getTags().stream().map(SchemaElement::getName) + Set tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName) .collect(Collectors.toSet()); - if (tags.containsAll(selectFields)) { + if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) { return QueryType.TAG; } } @@ -71,7 +71,7 @@ public class QueryTypeParser implements SemanticParser { //2. metric queryType List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL()); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - List metrics = semanticSchema.getMetrics(parseInfo.getModel().getModelIds()); + List metrics = semanticSchema.getMetrics(modelIds); if (CollectionUtils.isNotEmpty(metrics)) { Set metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet()); boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/ParseInfoUpdateProcessor.java b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/ParseInfoUpdateProcessor.java index a99a414d0..e5f80d2f4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/ParseInfoUpdateProcessor.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/ParseInfoUpdateProcessor.java @@ -14,14 +14,8 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.util.CollectionUtils; - import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -30,6 +24,10 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.util.CollectionUtils; /** * update parse info from correct sql @@ -47,6 +45,7 @@ public class ParseInfoUpdateProcessor implements PostProcessor { .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); candidateParses.forEach(this::updateParseInfo); } + public void updateParseInfo(SemanticParseInfo parseInfo) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); String correctS2SQL = sqlInfo.getCorrectS2SQL(); @@ -57,7 +56,6 @@ public class ParseInfoUpdateProcessor implements PostProcessor { if (correctS2SQL.equals(sqlInfo.getS2SQL())) { return; } - List expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL); //set dataInfo try { @@ -70,43 +68,35 @@ public class ParseInfoUpdateProcessor implements PostProcessor { } catch (Exception e) { log.error("set dateInfo error :", e); } - //set filter + Set modelIds = parseInfo.getModel().getModelIds(); try { - Map fieldNameToElement = getNameToElement(parseInfo.getModel().getModelIds()); + Map fieldNameToElement = getNameToElement(modelIds); List result = getDimensionFilter(fieldNameToElement, expressions); parseInfo.getDimensionFilters().addAll(result); } catch (Exception e) { log.error("set dimensionFilter error :", e); } - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - if (Objects.isNull(semanticSchema)) { return; } List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); - Set metrics = getElements(parseInfo.getModel().getModelIds(), - allFields, semanticSchema.getMetrics()); + Set metrics = getElements(modelIds, allFields, semanticSchema.getMetrics()); parseInfo.setMetrics(metrics); - - if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) { - parseInfo.setQueryType(QueryType.METRIC); + if (QueryType.METRIC.equals(parseInfo.getQueryType())) { List groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL()); List groupByDimensions = getFieldsExceptDate(groupByFields); - parseInfo.setDimensions( - getElements(parseInfo.getModel().getModelIds(), groupByDimensions, semanticSchema.getDimensions())); - } else { - parseInfo.setQueryType(QueryType.TAG); + parseInfo.setDimensions(getElements(modelIds, groupByDimensions, semanticSchema.getDimensions())); + } else if (QueryType.TAG.equals(parseInfo.getQueryType())) { List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL()); List selectDimensions = getFieldsExceptDate(selectFields); - parseInfo.setDimensions( - getElements(parseInfo.getModel().getModelIds(), selectDimensions, semanticSchema.getDimensions())); + parseInfo.setDimensions(getElements(modelIds, selectDimensions, semanticSchema.getDimensions())); } } - - private Set getElements(Set modelIds, List allFields, List elements) { + private Set getElements + (Set modelIds, List allFields, List elements) { return elements.stream() .filter(schemaElement -> modelIds.contains(schemaElement.getModel()) && allFields.contains(schemaElement.getName()) @@ -122,9 +112,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor { .collect(Collectors.toList()); } - private List getDimensionFilter(Map fieldNameToElement, - List fieldExpressions) { + List fieldExpressions) { List result = Lists.newArrayList(); for (FieldExpression expression : fieldExpressions) { QueryFilter dimensionFilter = new QueryFilter(); @@ -181,8 +170,9 @@ public class ParseInfoUpdateProcessor implements PostProcessor { } private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator, - FilterOperatorEnum... operatorEnums) { - return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); + FilterOperatorEnum... operatorEnums) { + return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull( + expression.getFieldValue())); } private boolean hasSecondDate(List dateExpressions) { @@ -210,7 +200,8 @@ public class ParseInfoUpdateProcessor implements PostProcessor { } return result.stream(); }) - .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); + .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), + (value1, value2) -> value2)); } } \ No newline at end of file