mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
(improvement)(chat) metric check compatible with count(*) (#416)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -18,7 +18,9 @@ import com.tencent.supersonic.knowledge.service.SchemaService;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
@@ -32,6 +34,7 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
|||||||
@Override
|
@Override
|
||||||
public void process(QueryContext queryContext) {
|
public void process(QueryContext queryContext) {
|
||||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
||||||
|
Map<Long, ModelSchema> modelSchemaMap = new HashMap<>();
|
||||||
for (SemanticQuery semanticQuery : semanticQueries) {
|
for (SemanticQuery semanticQuery : semanticQueries) {
|
||||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
if (!QueryType.METRIC.equals(parseInfo.getQueryType())) {
|
if (!QueryType.METRIC.equals(parseInfo.getQueryType())) {
|
||||||
@@ -39,8 +42,9 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
|||||||
}
|
}
|
||||||
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
||||||
ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
|
ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
|
||||||
String correctSqlProcessed = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
|
String processedSql = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
|
||||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed);
|
parseInfo.getSqlInfo().setCorrectS2SQL(processedSql);
|
||||||
|
modelSchemaMap.put(modelSchema.getModel().getModel(), modelSchema);
|
||||||
}
|
}
|
||||||
semanticQueries.removeIf(semanticQuery -> {
|
semanticQueries.removeIf(semanticQuery -> {
|
||||||
if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) {
|
if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) {
|
||||||
@@ -50,13 +54,14 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
|||||||
if (StringUtils.isBlank(correctSql)) {
|
if (StringUtils.isBlank(correctSql)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return CollectionUtils.isEmpty(SqlParserSelectHelper.getAggregateFields(correctSql));
|
return !checkHasMetric(correctSql, modelSchemaMap.get(semanticQuery.getParseInfo().getModelId()));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
public String processCorrectSql(String correctSql, ModelSchema modelSchema) {
|
public String processCorrectSql(String correctSql, ModelSchema modelSchema) {
|
||||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql);
|
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql);
|
||||||
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql);
|
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql)
|
||||||
|
.stream().filter(metricField -> !metricField.equals("*")).collect(Collectors.toList());
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctSql);
|
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctSql);
|
||||||
List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);
|
List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);
|
||||||
if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) {
|
if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) {
|
||||||
@@ -191,6 +196,19 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
|||||||
return schemaElement != null;
|
return schemaElement != null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean checkHasMetric(String correctSql, ModelSchema modelSchema) {
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctSql);
|
||||||
|
List<String> aggFields = SqlParserSelectHelper.getAggregateFields(correctSql);
|
||||||
|
List<String> collect = modelSchema.getMetrics().stream()
|
||||||
|
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||||
|
for (String field : selectFields) {
|
||||||
|
if (collect.contains(field)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return !CollectionUtils.isEmpty(aggFields);
|
||||||
|
}
|
||||||
|
|
||||||
private static String removeFieldInSql(String sql, Set<String> metricToRemove,
|
private static String removeFieldInSql(String sql, Set<String> metricToRemove,
|
||||||
Set<String> dimensionByToRemove, Set<String> whereFieldsToRemove) {
|
Set<String> dimensionByToRemove, Set<String> whereFieldsToRemove) {
|
||||||
sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);
|
sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder {
|
|||||||
|
|
||||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||||
if (!parseInfo.getQueryType().equals(QueryType.METRIC)
|
if (!parseInfo.getQueryType().equals(QueryType.METRIC)
|
||||||
&& parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE) {
|
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|
||||||
|
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
||||||
.interfaceName(postProcessor.getClass().getSimpleName())
|
.interfaceName(postProcessor.getClass().getSimpleName())
|
||||||
.type(CostType.POSTPROCESSOR.getType()).build());
|
.type(CostType.POSTPROCESSOR.getType()).build());
|
||||||
|
log.info("{} result:{}", postProcessor.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
||||||
});
|
});
|
||||||
//6. responder
|
//6. responder
|
||||||
parseResponders.forEach(parseResponder -> {
|
parseResponders.forEach(parseResponder -> {
|
||||||
@@ -169,6 +170,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
||||||
.interfaceName(parseResponder.getClass().getSimpleName())
|
.interfaceName(parseResponder.getClass().getSimpleName())
|
||||||
.type(CostType.PARSERRESPONDER.getType()).build());
|
.type(CostType.PARSERRESPONDER.getType()).build());
|
||||||
|
log.info("{} result:{}", parseResponder.getClass().getSimpleName(), JsonUtil.toString(parseResult));
|
||||||
});
|
});
|
||||||
|
|
||||||
if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) {
|
if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) {
|
||||||
|
|||||||
@@ -78,6 +78,26 @@ class MetricCheckPostProcessorTest {
|
|||||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
|
||||||
|
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||||
|
String correctSql = "select 访问次数 from 超音数";
|
||||||
|
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
|
||||||
|
mockModelSchemaNoDimensionSetting());
|
||||||
|
String expectedProcessedSql = "select 访问次数 from 超音数";
|
||||||
|
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
|
||||||
|
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||||
|
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||||
|
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
|
||||||
|
mockModelSchemaNoDimensionSetting());
|
||||||
|
String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||||
|
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 访问次数 drill down dimension is 用户名 and 部门
|
* 访问次数 drill down dimension is 用户名 and 部门
|
||||||
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
|
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
|
||||||
|
|||||||
Reference in New Issue
Block a user