(improvement)(chat) metric check compatible with count(*) (#416)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-11-22 18:29:58 +08:00
committed by GitHub
parent 42c0bea8fc
commit c168925f03
4 changed files with 46 additions and 5 deletions

View File

@@ -18,7 +18,9 @@ import com.tencent.supersonic.knowledge.service.SchemaService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -32,6 +34,7 @@ public class MetricCheckPostProcessor implements PostProcessor {
@Override
public void process(QueryContext queryContext) {
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
Map<Long, ModelSchema> modelSchemaMap = new HashMap<>();
for (SemanticQuery semanticQuery : semanticQueries) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (!QueryType.METRIC.equals(parseInfo.getQueryType())) {
@@ -39,8 +42,9 @@ public class MetricCheckPostProcessor implements PostProcessor {
}
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
String correctSqlProcessed = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed);
String processedSql = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
parseInfo.getSqlInfo().setCorrectS2SQL(processedSql);
modelSchemaMap.put(modelSchema.getModel().getModel(), modelSchema);
}
semanticQueries.removeIf(semanticQuery -> {
if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) {
@@ -50,13 +54,14 @@ public class MetricCheckPostProcessor implements PostProcessor {
if (StringUtils.isBlank(correctSql)) {
return false;
}
return CollectionUtils.isEmpty(SqlParserSelectHelper.getAggregateFields(correctSql));
return !checkHasMetric(correctSql, modelSchemaMap.get(semanticQuery.getParseInfo().getModelId()));
});
}
public String processCorrectSql(String correctSql, ModelSchema modelSchema) {
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql);
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql);
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql)
.stream().filter(metricField -> !metricField.equals("*")).collect(Collectors.toList());
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctSql);
List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);
if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) {
@@ -191,6 +196,19 @@ public class MetricCheckPostProcessor implements PostProcessor {
return schemaElement != null;
}
/**
 * Decides whether the corrected S2SQL still references a metric.
 *
 * <p>The SQL is considered to have a metric when one of its select fields equals the name of a
 * metric declared in {@code modelSchema}; otherwise the result falls back to whether the SQL
 * contains at least one aggregate field.
 *
 * @param correctSql the corrected S2SQL to inspect
 * @param modelSchema schema whose metric names are matched against the select fields; when it is
 *        {@code null} (e.g. a map lookup miss in the caller) or exposes no metrics, only the
 *        aggregate-field check is applied instead of failing with a NullPointerException
 * @return {@code true} if a select field is a schema metric or the SQL has an aggregate field
 */
private boolean checkHasMetric(String correctSql, ModelSchema modelSchema) {
    if (modelSchema != null && modelSchema.getMetrics() != null) {
        // A set gives constant-time membership checks for every select field.
        Set<String> metricNames = modelSchema.getMetrics().stream()
                .map(SchemaElement::getName)
                .collect(Collectors.toSet());
        List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctSql);
        if (selectFields != null && selectFields.stream().anyMatch(metricNames::contains)) {
            return true;
        }
    }
    // No select field names a known metric: accept the SQL only if it aggregates something.
    return !CollectionUtils.isEmpty(SqlParserSelectHelper.getAggregateFields(correctSql));
}
private static String removeFieldInSql(String sql, Set<String> metricToRemove,
Set<String> dimensionByToRemove, Set<String> whereFieldsToRemove) {
sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);

View File

@@ -33,7 +33,8 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder {
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
if (!parseInfo.getQueryType().equals(QueryType.METRIC)
&& parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE) {
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
return;
}
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());

View File

@@ -161,6 +161,7 @@ public class QueryServiceImpl implements QueryService {
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(postProcessor.getClass().getSimpleName())
.type(CostType.POSTPROCESSOR.getType()).build());
log.info("{} result:{}", postProcessor.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
//6. responder
parseResponders.forEach(parseResponder -> {
@@ -169,6 +170,7 @@ public class QueryServiceImpl implements QueryService {
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(parseResponder.getClass().getSimpleName())
.type(CostType.PARSERRESPONDER.getType()).build());
log.info("{} result:{}", parseResponder.getClass().getSimpleName(), JsonUtil.toString(parseResult));
});
if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) {

View File

@@ -78,6 +78,26 @@ class MetricCheckPostProcessorTest {
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
/**
 * A select of 访问次数 without any aggregate function, processed against the schema returned by
 * {@code mockModelSchemaNoDimensionSetting()}, is expected to come back unchanged.
 */
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
    String inputSql = "select 访问次数 from 超音数";
    String expectedSql = "select 访问次数 from 超音数";

    MetricCheckPostProcessor processor = new MetricCheckPostProcessor();
    String resultSql = processor.processCorrectSql(inputSql, mockModelSchemaNoDimensionSetting());

    Assertions.assertEquals(expectedSql, resultSql);
}
/**
 * A {@code count(*)} query grouped by 部门, processed against the schema returned by
 * {@code mockModelSchemaNoDimensionSetting()}, is expected to come back unchanged.
 */
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
    String inputSql = "select 部门, count(*) from 超音数 group by 部门";
    String expectedSql = "select 部门, count(*) from 超音数 group by 部门";

    MetricCheckPostProcessor processor = new MetricCheckPostProcessor();
    String resultSql = processor.processCorrectSql(inputSql, mockModelSchemaNoDimensionSetting());

    Assertions.assertEquals(expectedSql, resultSql);
}
/**
* 访问次数 drill down dimension is 用户名 and 部门
* 访问用户数 drill down dimension is 部门, and 部门 is required: it must appear in the select and group by or where expressions