mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) metric check compatible with count(*) (#416)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -18,7 +18,9 @@ import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -32,6 +34,7 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
||||
@Override
|
||||
public void process(QueryContext queryContext) {
|
||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
||||
Map<Long, ModelSchema> modelSchemaMap = new HashMap<>();
|
||||
for (SemanticQuery semanticQuery : semanticQueries) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
if (!QueryType.METRIC.equals(parseInfo.getQueryType())) {
|
||||
@@ -39,8 +42,9 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
||||
}
|
||||
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
|
||||
String correctSqlProcessed = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed);
|
||||
String processedSql = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(processedSql);
|
||||
modelSchemaMap.put(modelSchema.getModel().getModel(), modelSchema);
|
||||
}
|
||||
semanticQueries.removeIf(semanticQuery -> {
|
||||
if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) {
|
||||
@@ -50,13 +54,14 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
||||
if (StringUtils.isBlank(correctSql)) {
|
||||
return false;
|
||||
}
|
||||
return CollectionUtils.isEmpty(SqlParserSelectHelper.getAggregateFields(correctSql));
|
||||
return !checkHasMetric(correctSql, modelSchemaMap.get(semanticQuery.getParseInfo().getModelId()));
|
||||
});
|
||||
}
|
||||
|
||||
public String processCorrectSql(String correctSql, ModelSchema modelSchema) {
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql);
|
||||
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql);
|
||||
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql)
|
||||
.stream().filter(metricField -> !metricField.equals("*")).collect(Collectors.toList());
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctSql);
|
||||
List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);
|
||||
if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) {
|
||||
@@ -191,6 +196,19 @@ public class MetricCheckPostProcessor implements PostProcessor {
|
||||
return schemaElement != null;
|
||||
}
|
||||
|
||||
private boolean checkHasMetric(String correctSql, ModelSchema modelSchema) {
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctSql);
|
||||
List<String> aggFields = SqlParserSelectHelper.getAggregateFields(correctSql);
|
||||
List<String> collect = modelSchema.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
for (String field : selectFields) {
|
||||
if (collect.contains(field)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return !CollectionUtils.isEmpty(aggFields);
|
||||
}
|
||||
|
||||
private static String removeFieldInSql(String sql, Set<String> metricToRemove,
|
||||
Set<String> dimensionByToRemove, Set<String> whereFieldsToRemove) {
|
||||
sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);
|
||||
|
||||
@@ -33,7 +33,8 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder {
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
if (!parseInfo.getQueryType().equals(QueryType.METRIC)
|
||||
&& parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE) {
|
||||
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|
||||
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||
|
||||
@@ -161,6 +161,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
||||
.interfaceName(postProcessor.getClass().getSimpleName())
|
||||
.type(CostType.POSTPROCESSOR.getType()).build());
|
||||
log.info("{} result:{}", postProcessor.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
||||
});
|
||||
//6. responder
|
||||
parseResponders.forEach(parseResponder -> {
|
||||
@@ -169,6 +170,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
||||
.interfaceName(parseResponder.getClass().getSimpleName())
|
||||
.type(CostType.PARSERRESPONDER.getType()).build());
|
||||
log.info("{} result:{}", parseResponder.getClass().getSimpleName(), JsonUtil.toString(parseResult));
|
||||
});
|
||||
|
||||
if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) {
|
||||
|
||||
@@ -78,6 +78,26 @@ class MetricCheckPostProcessorTest {
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 访问次数 from 超音数";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "select 访问次数 from 超音数";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
|
||||
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
|
||||
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
|
||||
mockModelSchemaNoDimensionSetting());
|
||||
String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门";
|
||||
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
|
||||
}
|
||||
|
||||
/**
|
||||
* 访问次数 drill down dimension is 用户名 and 部门
|
||||
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
|
||||
|
||||
Reference in New Issue
Block a user