From c168925f0321efbeef1526d79beb443ab81a2bc3 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Wed, 22 Nov 2023 18:29:58 +0800 Subject: [PATCH] (improvement)(chat) metric check compatible with count(*) (#416) Co-authored-by: jolunoluo --- .../MetricCheckPostProcessor.java | 26 ++++++++++++++++--- .../SimilarMetricExecuteResponder.java | 3 ++- .../chat/service/impl/QueryServiceImpl.java | 2 ++ .../MetricCheckPostProcessorTest.java | 20 ++++++++++++++ 4 files changed, 46 insertions(+), 5 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessor.java b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessor.java index f5664da85..b89d2fbcf 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessor.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessor.java @@ -18,7 +18,9 @@ import com.tencent.supersonic.knowledge.service.SchemaService; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -32,6 +34,7 @@ public class MetricCheckPostProcessor implements PostProcessor { @Override public void process(QueryContext queryContext) { List semanticQueries = queryContext.getCandidateQueries(); + Map modelSchemaMap = new HashMap<>(); for (SemanticQuery semanticQuery : semanticQueries) { SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); if (!QueryType.METRIC.equals(parseInfo.getQueryType())) { @@ -39,8 +42,9 @@ public class MetricCheckPostProcessor implements PostProcessor { } SchemaService schemaService = ContextUtils.getBean(SchemaService.class); ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId()); - String correctSqlProcessed = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema); - parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed); + String processedSql = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema); + parseInfo.getSqlInfo().setCorrectS2SQL(processedSql); + modelSchemaMap.put(modelSchema.getModel().getModel(), modelSchema); } semanticQueries.removeIf(semanticQuery -> { if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) { @@ -50,13 +54,14 @@ public class MetricCheckPostProcessor implements PostProcessor { if (StringUtils.isBlank(correctSql)) { return false; } - return CollectionUtils.isEmpty(SqlParserSelectHelper.getAggregateFields(correctSql)); + return !checkHasMetric(correctSql, modelSchemaMap.get(semanticQuery.getParseInfo().getModelId())); }); } public String processCorrectSql(String correctSql, ModelSchema modelSchema) { List groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql); - List metricFields = SqlParserSelectHelper.getAggregateFields(correctSql); + List metricFields = SqlParserSelectHelper.getAggregateFields(correctSql) + .stream().filter(metricField -> !metricField.equals("*")).collect(Collectors.toList()); List whereFields = SqlParserSelectHelper.getWhereFields(correctSql); List dimensionFields = getDimensionFields(groupByFields, whereFields); if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) { @@ -191,6 +196,19 @@ public class MetricCheckPostProcessor implements PostProcessor { return schemaElement != null; } + private boolean checkHasMetric(String correctSql, ModelSchema modelSchema) { + List selectFields = SqlParserSelectHelper.getSelectFields(correctSql); + List aggFields = SqlParserSelectHelper.getAggregateFields(correctSql); + List collect = modelSchema.getMetrics().stream() + .map(SchemaElement::getName).collect(Collectors.toList()); + for (String field : selectFields) { + if (collect.contains(field)) { + return true; + } + } + return !CollectionUtils.isEmpty(aggFields); + } + private static String removeFieldInSql(String sql, Set metricToRemove, Set dimensionByToRemove, Set whereFieldsToRemove) { sql = SqlParserRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java index 46077231f..e83c95f70 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java @@ -33,7 +33,8 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder { private void fillSimilarMetric(SemanticParseInfo parseInfo) { if (!parseInfo.getQueryType().equals(QueryType.METRIC) - && parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE) { + || parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE + || CollectionUtils.isEmpty(parseInfo.getMetrics())) { return; } List metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index a0a16d9e4..0c89ccc5b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -161,6 +161,7 @@ public class QueryServiceImpl implements QueryService { timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(postProcessor.getClass().getSimpleName()) .type(CostType.POSTPROCESSOR.getType()).build()); + log.info("{} result:{}", postProcessor.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); }); //6. responder parseResponders.forEach(parseResponder -> { @@ -169,6 +170,7 @@ public class QueryServiceImpl implements QueryService { timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(parseResponder.getClass().getSimpleName()) .type(CostType.PARSERRESPONDER.getType()).build()); + log.info("{} result:{}", parseResponder.getClass().getSimpleName(), JsonUtil.toString(parseResult)); }); if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) { diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessorTest.java index 35c979858..65235316d 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessorTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/postprocessor/MetricCheckPostProcessorTest.java @@ -78,6 +78,26 @@ class MetricCheckPostProcessorTest { Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); } + @Test + void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 访问次数 from 超音数"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, + mockModelSchemaNoDimensionSetting()); + String expectedProcessedSql = "select 访问次数 from 超音数"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + + @Test + void testProcessCorrectSql_noDrillDownDimensionSetting_count() { + MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor(); + String correctSql = "select 部门, count(*) from 超音数 group by 部门"; + String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, + mockModelSchemaNoDimensionSetting()); + String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门"; + Assertions.assertEquals(expectedProcessedSql, actualProcessedSql); + } + /** * 访问次数 drill down dimension is 用户名 and 部门 * 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions