(improvement)(headless) Remove MetricCheckProcessor in chat and MetricDrillDownChecker in headless (#716)

(improvement)(headless) remove MetricCheckProcessor in chat and MetricDrillDownChecker in headless

---------

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-02-04 14:28:24 +08:00
committed by GitHub
parent 4d4922d269
commit 0c4c6d83ef
18 changed files with 400 additions and 417 deletions

View File

@@ -33,9 +33,9 @@ public abstract class TagListQuery extends TagSemanticQuery {
Set<SchemaElement> metrics = new LinkedHashSet<>();
Set<Order> orders = new LinkedHashSet<>();
TagTypeDefaultConfig tagTypeDefaultConfig = viewSchema.getTagTypeDefaultConfig();
if (tagTypeDefaultConfig != null) {
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getMetricIds())) {
metrics = tagTypeDefaultConfig.getMetricIds()
if (tagTypeDefaultConfig != null && tagTypeDefaultConfig.getDefaultDisplayInfo() != null) {
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
metrics = tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
.stream().map(id -> {
SchemaElement metric = viewSchema.getElement(SchemaElementType.METRIC, id);
if (metric != null) {
@@ -44,8 +44,8 @@ public abstract class TagListQuery extends TagSemanticQuery {
return metric;
}).filter(Objects::nonNull).collect(Collectors.toSet());
}
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDimensionIds())) {
dimensions = tagTypeDefaultConfig.getDimensionIds().stream()
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
dimensions = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> viewSchema.getElement(SchemaElementType.DIMENSION, id))
.filter(Objects::nonNull).collect(Collectors.toSet());
}

View File

@@ -1,220 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
/**
 * MetricCheckProcessor verifies whether the dimensions
 * involved in the query in metric mode can drill down on the metric.
 *
 * <p>For every candidate query of type {@code METRIC} the corrected S2SQL is rewritten: metrics
 * that are unknown to the schema or that miss one of their necessary dimensions are removed, and
 * so are where/group-by dimensions that are unknown to the schema or not configured as drill-down
 * dimensions. Candidates whose rewritten SQL no longer contains any metric are dropped.
 */
@Slf4j
public class MetricCheckProcessor implements ParseResultProcessor {

    /**
     * Rewrites the corrected S2SQL of every metric candidate query and removes the candidates
     * that end up without a metric.
     *
     * @param parseResp    parse response (not read by this processor)
     * @param queryContext context holding the candidate queries, which are modified in place
     * @param chatContext  chat context (not read by this processor)
     */
    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
        SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
        SemanticSchema semanticSchema = semanticService.getSemanticSchema();
        for (SemanticQuery semanticQuery : semanticQueries) {
            SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
            if (!QueryType.METRIC.equals(parseInfo.getQueryType())) {
                continue;
            }
            String correctSqlProcessed = processCorrectSql(parseInfo, semanticSchema);
            log.info("correct sql:{}", correctSqlProcessed);
            parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed);
        }
        semanticQueries.removeIf(semanticQuery -> {
            if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) {
                return false;
            }
            String correctSql = semanticQuery.getParseInfo().getSqlInfo().getCorrectS2SQL();
            if (StringUtils.isBlank(correctSql)) {
                return false;
            }
            return !checkHasMetric(correctSql, semanticSchema);
        });
    }

    /**
     * Removes invalid metrics and dimensions from the corrected S2SQL of the given parse info.
     *
     * @param parseInfo      parse info whose corrected S2SQL is inspected (it is not modified here)
     * @param semanticSchema schema used to resolve metric and dimension names
     * @return the rewritten SQL, or the original SQL unchanged when it is blank or contains no
     *         aggregate field
     */
    public String processCorrectSql(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
        String correctSql = parseInfo.getSqlInfo().getCorrectS2SQL();
        // Reject blank SQL before it reaches the parser helpers instead of after.
        if (StringUtils.isBlank(correctSql)) {
            return correctSql;
        }
        List<String> groupByFields = SqlSelectHelper.getGroupByFields(correctSql);
        List<String> metricFields = SqlSelectHelper.getAggregateFields(correctSql);
        List<String> whereFields = SqlSelectHelper.getWhereFields(correctSql);
        if (CollectionUtils.isEmpty(metricFields)) {
            return correctSql;
        }
        List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);

        Set<String> metricToRemove = Sets.newHashSet();
        for (String metricName : metricFields) {
            SchemaElement metricElement = semanticSchema.getElementByName(SchemaElementType.METRIC, metricName);
            if (metricElement == null
                    || !checkNecessaryDimension(metricElement, semanticSchema, dimensionFields)) {
                metricToRemove.add(metricName);
            }
        }

        // The queried metrics and their drill-down dimensions do not depend on the dimension
        // being checked, so resolve them once for both the where and the group-by pass.
        List<SchemaElement> queriedMetrics = semanticSchema.getMetrics().stream()
                .filter(schemaElement -> metricFields.contains(schemaElement.getName()))
                .collect(Collectors.toList());
        List<String> drillDownDimensions = getDrillDownDimensionNames(queriedMetrics, semanticSchema);

        Set<String> whereFieldsToRemove =
                findDimensionsToRemove(whereFields, queriedMetrics, drillDownDimensions, semanticSchema);
        Set<String> groupByToRemove =
                findDimensionsToRemove(groupByFields, queriedMetrics, drillDownDimensions, semanticSchema);
        return removeFieldInSql(correctSql, metricToRemove, groupByToRemove, whereFieldsToRemove);
    }

    /**
     * Collects the non-time dimensions that are either missing from the schema or not usable as
     * drill-down dimensions of the queried metrics.
     */
    private Set<String> findDimensionsToRemove(List<String> dimensionNames, List<SchemaElement> queriedMetrics,
            List<String> drillDownDimensions, SemanticSchema semanticSchema) {
        Set<String> toRemove = Sets.newHashSet();
        if (CollectionUtils.isEmpty(dimensionNames)) {
            return toRemove;
        }
        for (String dimensionName : dimensionNames) {
            if (TimeDimensionEnum.containsTimeDimension(dimensionName)) {
                continue;
            }
            if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, semanticSchema)
                    || !checkDrillDownDimension(dimensionName, queriedMetrics, drillDownDimensions)) {
                toRemove.add(dimensionName);
            }
        }
        return toRemove;
    }

    /**
     * To check whether the dimension bound to the metric exists,
     * eg: metric like UV is calculated in a certain dimension, it cannot be used on other dimensions.
     */
    private boolean checkNecessaryDimension(SchemaElement metric, SemanticSchema semanticSchema,
            List<String> dimensionFields) {
        List<String> necessaryDimensions = getNecessaryDimensionNames(metric, semanticSchema);
        // containsAll is trivially true for an empty list, i.e. when nothing is necessary.
        return dimensionFields.containsAll(necessaryDimensions);
    }

    /**
     * To check whether the dimension can drill down the metric,
     * eg: some descriptive dimensions are not suitable as drill-down dimensions
     */
    private boolean checkDrillDownDimension(String dimensionName, List<SchemaElement> queriedMetrics,
            List<String> drillDownDimensions) {
        // none of the aggregated fields is a metric of the schema
        if (CollectionUtils.isEmpty(queriedMetrics)) {
            return false;
        }
        //if no metric has drill down dimension, return true
        if (CollectionUtils.isEmpty(drillDownDimensions)) {
            return true;
        }
        //if this dimension not in relate drill-down dimensions, return false
        return drillDownDimensions.contains(dimensionName);
    }

    /**
     * Resolves the names of all dimensions related to any of the given metrics; ids that the
     * schema cannot resolve are skipped.
     */
    private List<String> getDrillDownDimensionNames(List<SchemaElement> queriedMetrics,
            SemanticSchema semanticSchema) {
        return queriedMetrics.stream()
                .filter(schemaElement -> !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements()))
                .map(schemaElement -> schemaElement.getRelatedSchemaElements().stream()
                        .map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList()))
                .flatMap(Collection::stream)
                .map(id -> convertDimensionIdToName(id, semanticSchema))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /** Resolves the names of the dimensions marked as necessary for the metric. */
    private List<String> getNecessaryDimensionNames(SchemaElement metric, SemanticSchema semanticSchema) {
        List<Long> necessaryDimensionIds = getNecessaryDimensions(metric);
        return necessaryDimensionIds.stream().map(id -> convertDimensionIdToName(id, semanticSchema))
                .filter(Objects::nonNull).collect(Collectors.toList());
    }

    /** Returns the ids of the metric's necessary dimensions; empty for a null or unrelated metric. */
    private List<Long> getNecessaryDimensions(SchemaElement metric) {
        if (metric == null) {
            return Lists.newArrayList();
        }
        List<RelatedSchemaElement> relateSchemaElements = metric.getRelatedSchemaElements();
        if (CollectionUtils.isEmpty(relateSchemaElements)) {
            return Lists.newArrayList();
        }
        return relateSchemaElements.stream()
                .filter(RelatedSchemaElement::isNecessary).map(RelatedSchemaElement::getDimensionId)
                .collect(Collectors.toList());
    }

    /** Concatenates the group-by and where fields, tolerating null or empty inputs. */
    private List<String> getDimensionFields(List<String> groupByFields, List<String> whereFields) {
        List<String> dimensionFields = Lists.newArrayList();
        if (!CollectionUtils.isEmpty(groupByFields)) {
            dimensionFields.addAll(groupByFields);
        }
        if (!CollectionUtils.isEmpty(whereFields)) {
            dimensionFields.addAll(whereFields);
        }
        return dimensionFields;
    }

    /** Returns the name of the dimension with the given id, or null when the schema has none. */
    private String convertDimensionIdToName(Long id, SemanticSchema semanticSchema) {
        SchemaElement schemaElement = semanticSchema.getElement(SchemaElementType.DIMENSION, id);
        if (schemaElement == null) {
            return null;
        }
        return schemaElement.getName();
    }

    /** Tells whether the schema contains an element of the given type and name. */
    private boolean checkInModelSchema(String name, SchemaElementType type, SemanticSchema semanticSchema) {
        SchemaElement schemaElement = semanticSchema.getElementByName(type, name);
        return schemaElement != null;
    }

    /**
     * Tells whether the SQL still selects a metric: either a select field carries the name of a
     * schema metric, or the SQL contains at least one aggregate field.
     */
    private boolean checkHasMetric(String correctSql, SemanticSchema semanticSchema) {
        List<String> selectFields = SqlSelectHelper.getSelectFields(correctSql);
        List<String> aggFields = SqlSelectHelper.getAggregateFields(correctSql);
        List<String> collect = semanticSchema.getMetrics().stream()
                .map(SchemaElement::getName).collect(Collectors.toList());
        for (String field : selectFields) {
            if (collect.contains(field)) {
                return true;
            }
        }
        return !CollectionUtils.isEmpty(aggFields);
    }

    /**
     * Applies the removals to the SQL: where conditions first, then select items for metrics and
     * dimensions, then group-by items, and finally number filters.
     */
    private static String removeFieldInSql(String sql, Set<String> metricToRemove,
            Set<String> dimensionByToRemove, Set<String> whereFieldsToRemove) {
        sql = SqlRemoveHelper.removeWhereCondition(sql, whereFieldsToRemove);
        sql = SqlRemoveHelper.removeSelect(sql, metricToRemove);
        sql = SqlRemoveHelper.removeSelect(sql, dimensionByToRemove);
        sql = SqlRemoveHelper.removeGroupBy(sql, dimensionByToRemove);
        sql = SqlRemoveHelper.removeNumberFilter(sql);
        return sql;
    }
}

View File

@@ -82,12 +82,17 @@ public class ChatConfigController {
return semanticInterpreter.getDomainList(user);
}
//Compatible with front-end
/**
 * Lists views without restricting them to a domain: the interpreter is asked with a
 * {@code null} domain id.
 *
 * @return the views returned by the semantic interpreter
 */
@GetMapping("/viewList")
public List<ViewResp> getViewList() {
//Compatible with front-end: keeps the parameterless endpoint available next to /viewList/{domainId}
return semanticInterpreter.getViewList(null);
}
/**
 * Lists the views for the domain given in the request path.
 *
 * @param domainId domain id taken from the {@code domainId} path variable
 * @return the views returned by the semantic interpreter for that domain id
 */
@GetMapping("/viewList/{domainId}")
public List<ViewResp> getViewList(@PathVariable("domainId") Long domainId) {
return semanticInterpreter.getViewList(domainId);
}
@PostMapping("/dimension/page")
public PageInfo<DimensionResp> getDimension(@RequestBody PageDimensionReq pageDimensionReq) {
return semanticInterpreter.getDimensionPage(pageDimensionReq);

View File

@@ -102,10 +102,10 @@ public class SemanticService {
}
entityInfo.setViewInfo(viewInfo);
TagTypeDefaultConfig tagTypeDefaultConfig = viewSchema.getTagTypeDefaultConfig();
if (tagTypeDefaultConfig == null) {
if (tagTypeDefaultConfig == null || tagTypeDefaultConfig.getDefaultDisplayInfo() == null) {
return entityInfo;
}
List<DataInfo> dimensions = tagTypeDefaultConfig.getDimensionIds().stream()
List<DataInfo> dimensions = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> {
SchemaElement element = viewSchema.getElement(SchemaElementType.DIMENSION, id);
if (element == null) {
@@ -113,7 +113,7 @@ public class SemanticService {
}
return new DataInfo(element.getId().intValue(), element.getName(), element.getBizName(), null);
}).filter(Objects::nonNull).collect(Collectors.toList());
List<DataInfo> metrics = tagTypeDefaultConfig.getDimensionIds().stream()
List<DataInfo> metrics = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> {
SchemaElement element = viewSchema.getElement(SchemaElementType.METRIC, id);
if (element == null) {

View File

@@ -1,162 +0,0 @@
package com.tencent.supersonic.chat.server.processor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.server.processor.parse.MetricCheckProcessor;
import java.util.List;
import java.util.Set;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class MetricCheckProcessorTest {
@Test
void testProcessCorrectSql_necessaryDimension_groupBy() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_necessaryDimension_where() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 "
+ "WHERE 部门 = 'HR' GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionNotDrillDown_groupBy() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionNotDrillDown_where() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT sum(访问次数) FROM 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_dimensionDrillDown() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 访问次数 from 超音数";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "select 访问次数 from 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@Test
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
MetricCheckProcessor metricCheckPostProcessor = new MetricCheckProcessor();
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "SELECT count(*) FROM 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
/**
* 访问次数 drill down dimension is 用户名 and 部门
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
*/
private SemanticSchema mockModelSchema() {
ViewSchema modelSchema = new ViewSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC,
Lists.newArrayList(RelatedSchemaElement.builder().dimensionId(2L).isNecessary(false).build(),
RelatedSchemaElement.builder().dimensionId(1L).isNecessary(false).build())),
mockElement(2L, "访问用户数", SchemaElementType.METRIC,
Lists.newArrayList(RelatedSchemaElement.builder().dimensionId(2L).isNecessary(true).build()))
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private SemanticSchema mockModelSchemaNoDimensionSetting() {
ViewSchema modelSchema = new ViewSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList()),
mockElement(2L, "访问用户数", SchemaElementType.METRIC, Lists.newArrayList())
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private Set<SchemaElement> mockDimensions() {
return Sets.newHashSet(
mockElement(1L, "用户名", SchemaElementType.DIMENSION, Lists.newArrayList()),
mockElement(2L, "部门", SchemaElementType.DIMENSION, Lists.newArrayList()),
mockElement(3L, "页面", SchemaElementType.DIMENSION, Lists.newArrayList())
);
}
private SchemaElement mockElement(Long id, String name, SchemaElementType type,
List<RelatedSchemaElement> relateSchemaElements) {
return SchemaElement.builder().id(id).name(name).type(type)
.relatedSchemaElements(relateSchemaElements).build();
}
private SemanticParseInfo mockParseInfo(String correctSql) {
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
return semanticParseInfo;
}
}

View File

@@ -1,13 +1,8 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.CaseExpression;
import net.sf.jsqlparser.expression.Expression;
@@ -38,6 +33,13 @@ import net.sf.jsqlparser.statement.select.SubSelect;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Sql Parser Select Helper
*/
@@ -365,6 +367,36 @@ public class SqlSelectHelper {
return new ArrayList<>(result);
}
/**
 * Collects one field name per function-call select item of the statement: the item's alias when
 * it has a non-blank one, otherwise the text of the function's first argument. Function items
 * without an alias and without arguments contribute nothing.
 *
 * @param sql the SQL statement to inspect
 * @return the distinct collected names, in no particular order
 */
public static List<String> getAggregateAsFields(String sql) {
    Set<String> fieldNames = new HashSet<>();
    for (PlainSelect select : getPlainSelect(sql)) {
        if (Objects.isNull(select)) {
            continue;
        }
        for (SelectItem item : select.getSelectItems()) {
            if (!(item instanceof SelectExpressionItem)) {
                continue;
            }
            SelectExpressionItem expressionItem = (SelectExpressionItem) item;
            Expression expression = expressionItem.getExpression();
            if (!(expression instanceof Function)) {
                continue;
            }
            // An explicit, non-blank alias wins over the function argument.
            Alias alias = expressionItem.getAlias();
            if (alias != null && StringUtils.isNotBlank(alias.getName())) {
                fieldNames.add(alias.getName());
                continue;
            }
            Function function = (Function) expression;
            boolean hasArguments = Objects.nonNull(function.getParameters())
                    && !CollectionUtils.isEmpty(function.getParameters().getExpressions());
            if (hasArguments) {
                fieldNames.add(function.getParameters().getExpressions().get(0).toString());
            }
        }
    }
    return new ArrayList<>(fieldNames);
}
public static boolean hasGroupBy(String sql) {
Select selectStatement = getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.headless.api.pojo;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
/**
 * Default display settings: the dimension and metric ids shown by default when tag selection
 * results are displayed. Both lists start out empty; accessors come from Lombok's {@code @Data}.
 */
@Data
public class DefaultDisplayInfo {
//When displaying tag selection results, the information displayed by default
// ids of the dimensions displayed by default
private List<Long> dimensionIds = new ArrayList<>();
// ids of the metrics displayed by default
private List<Long> metricIds = new ArrayList<>();
}

View File

@@ -2,15 +2,10 @@ package com.tencent.supersonic.headless.api.pojo;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
@Data
public class TagTypeDefaultConfig {
//When displaying tag selection results, the information displayed by default
private List<Long> dimensionIds = new ArrayList<>();
private List<Long> metricIds = new ArrayList<>();
private DefaultDisplayInfo defaultDisplayInfo;
//default time to filter tag selection results
private TimeDefaultConfig timeDefaultConfig;

View File

@@ -77,6 +77,13 @@ public class MetricResp extends SchemaItem {
.collect(Collectors.joining(","));
}
/**
 * Returns the drill-down dimensions configured on {@code relateDimension}.
 *
 * @return the configured list, or a new empty list when {@code relateDimension} is null or
 *         its drill-down list is null or empty
 */
public List<DrillDownDimension> getDrillDownDimensions() {
    List<DrillDownDimension> configured =
            relateDimension == null ? null : relateDimension.getDrillDownDimensions();
    return CollectionUtils.isEmpty(configured) ? Lists.newArrayList() : configured;
}
public String getDefaultAgg() {
if (metricDefineByMeasureParams != null
&& CollectionUtils.isNotEmpty(metricDefineByMeasureParams.getMeasures())) {

View File

@@ -35,4 +35,24 @@ public class SemanticSchemaResp {
}
/**
 * Finds the first metric whose bizName equals the given name, ignoring case.
 *
 * @param bizName bizName to look for; may be null
 * @return the matching metric, or null when {@code bizName} is null or no metric matches
 */
public MetricSchemaResp getMetric(String bizName) {
    // A null key cannot match anything; answer null instead of failing inside the filter.
    if (bizName == null) {
        return null;
    }
    return metrics.stream().filter(metric -> bizName.equalsIgnoreCase(metric.getBizName()))
            .findFirst().orElse(null);
}
/**
 * Finds the first metric with the given id.
 *
 * @param id metric id to look for; may be null
 * @return the matching metric, or null when {@code id} is null or no metric matches
 */
public MetricSchemaResp getMetric(Long id) {
    // A null key cannot match anything; answer null instead of failing inside the filter.
    if (id == null) {
        return null;
    }
    return metrics.stream().filter(metric -> id.equals(metric.getId()))
            .findFirst().orElse(null);
}
/**
 * Finds the first dimension whose bizName equals the given name, ignoring case.
 *
 * @param bizName bizName to look for; may be null
 * @return the matching dimension, or null when {@code bizName} is null or no dimension matches
 */
public DimSchemaResp getDimension(String bizName) {
    // A null key cannot match anything; answer null instead of failing inside the filter.
    if (bizName == null) {
        return null;
    }
    return dimensions.stream().filter(dimension -> bizName.equalsIgnoreCase(dimension.getBizName()))
            .findFirst().orElse(null);
}
/**
 * Finds the first dimension with the given id.
 *
 * @param id dimension id to look for; may be null
 * @return the matching dimension, or null when {@code id} is null or no dimension matches
 */
public DimSchemaResp getDimension(Long id) {
    // A null key cannot match anything; answer null instead of failing inside the filter.
    if (id == null) {
        return null;
    }
    return dimensions.stream().filter(dimension -> id.equals(dimension.getId()))
            .findFirst().orElse(null);
}
}

View File

@@ -0,0 +1,150 @@
package com.tencent.supersonic.headless.server.aspect;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Aspect around {@code QueryParser.parse(..)} that validates SQL queries against the drill-down
 * configuration of the queried metrics before the parser runs: every necessary dimension of a
 * metric must appear in the group-by or where fields, and every non-time group-by field must be
 * one of the configured drill-down dimensions (when any are configured).
 */
@Aspect
@Component
@Slf4j
public class MetricDrillDownChecker {

    /**
     * Checks the query of the intercepted call, then lets the call proceed.
     *
     * @param joinPoint the intercepted {@code parse} invocation
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the check or by the intercepted method
     */
    @Around("execution(* com.tencent.supersonic.headless.core.parser.QueryParser.parse(..))")
    public Object doAround(ProceedingJoinPoint joinPoint) throws Throwable {
        Object[] args = joinPoint.getArgs();
        // The pointcut matches any parse(..) signature, so only inspect a leading QueryStatement.
        if (args != null && args.length > 0 && args[0] instanceof QueryStatement) {
            QueryStatement queryStatement = (QueryStatement) args[0];
            if (queryStatement.getParseSqlReq() != null) {
                checkQuery(queryStatement.getSemanticSchemaResp(), queryStatement.getParseSqlReq().getSql());
            }
        }
        return joinPoint.proceed();
    }

    /**
     * Validates the SQL against the drill-down settings of the metrics it aggregates. Blank SQL
     * and SQL without aggregate fields pass unchecked.
     *
     * @param semanticSchemaResp schema used to resolve metrics and dimensions by bizName or id
     * @param sql                the SQL to validate
     * @throws InvalidArgumentException when a metric lacks one of its necessary dimensions, or a
     *                                  group-by field is not a configured drill-down dimension
     */
    public void checkQuery(SemanticSchemaResp semanticSchemaResp, String sql) {
        // Reject blank SQL before it reaches the parser helpers instead of after.
        if (StringUtils.isBlank(sql)) {
            return;
        }
        List<String> groupByFields = SqlSelectHelper.getGroupByFields(sql);
        List<String> metricFields = SqlSelectHelper.getAggregateAsFields(sql);
        List<String> whereFields = SqlSelectHelper.getWhereFields(sql);
        if (CollectionUtils.isEmpty(metricFields)) {
            return;
        }
        List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);
        for (String metricName : metricFields) {
            MetricSchemaResp metric = semanticSchemaResp.getMetric(metricName);
            List<DimensionResp> necessaryDimensions = getNecessaryDimensions(metric, semanticSchemaResp);
            List<DimensionResp> dimensionsMissing = getNecessaryDimensionMissing(necessaryDimensions, dimensionFields);
            // A non-empty result implies metric != null: unknown metrics have no necessary dimensions.
            if (!CollectionUtils.isEmpty(dimensionsMissing)) {
                String errMsg = String.format("指标:%s 缺失必要维度:%s", metric.getName(),
                        dimensionsMissing.stream().map(DimensionResp::getName).collect(Collectors.toList()));
                throw new InvalidArgumentException(errMsg);
            }
        }
        if (CollectionUtils.isEmpty(groupByFields)) {
            return;
        }
        // The queried metrics are the same for every group-by field; resolve them once.
        List<MetricResp> metricResps = getMetrics(metricFields, semanticSchemaResp);
        for (String dimensionBizName : groupByFields) {
            if (TimeDimensionEnum.containsTimeDimension(dimensionBizName)) {
                continue;
            }
            if (!checkDrillDownDimension(dimensionBizName, metricResps, semanticSchemaResp)) {
                DimSchemaResp dimSchemaResp = semanticSchemaResp.getDimension(dimensionBizName);
                // The group-by field may not be a dimension of the schema at all; fall back to the
                // raw field name so the intended validation error is raised instead of an NPE.
                String displayName = dimSchemaResp == null ? dimensionBizName : dimSchemaResp.getName();
                String errMsg = String.format("维度:%s, 不在当前查询指标的下钻维度配置中, 请检查", displayName);
                throw new InvalidArgumentException(errMsg);
            }
        }
    }

    /**
     * To check whether the dimension bound to the metric exists,
     * eg: metric like UV is calculated in a certain dimension, it cannot be used on other dimensions.
     *
     * @return the necessary dimensions whose bizName is absent from {@code dimensionFields}
     */
    private List<DimensionResp> getNecessaryDimensionMissing(List<DimensionResp> necessaryDimensions,
            List<String> dimensionFields) {
        return necessaryDimensions.stream()
                .filter(dimension -> !dimensionFields.contains(dimension.getBizName()))
                .collect(Collectors.toList());
    }

    /**
     * To check whether the dimension can drill down the metric,
     * eg: some descriptive dimensions are not suitable as drill-down dimensions
     */
    private boolean checkDrillDownDimension(String dimensionName,
            List<MetricResp> metricResps,
            SemanticSchemaResp semanticSchemaResp) {
        if (CollectionUtils.isEmpty(metricResps)) {
            return true;
        }
        List<String> relateDimensions = metricResps.stream()
                .filter(metric -> !CollectionUtils.isEmpty(metric.getDrillDownDimensions()))
                .map(metric -> metric.getDrillDownDimensions().stream()
                        .map(DrillDownDimension::getDimensionId).collect(Collectors.toList()))
                .flatMap(Collection::stream)
                .map(id -> convertDimensionIdToBizName(id, semanticSchemaResp))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        //if no metric has drill down dimension, return true
        if (CollectionUtils.isEmpty(relateDimensions)) {
            return true;
        }
        //if this dimension not in relate drill-down dimensions, return false
        return relateDimensions.contains(dimensionName);
    }

    /**
     * Resolves the drill-down dimensions of the metric that are flagged as necessary; empty for a
     * null metric or a metric without drill-down dimensions. Ids unknown to the schema are skipped.
     */
    private List<DimensionResp> getNecessaryDimensions(MetricSchemaResp metric, SemanticSchemaResp semanticSchemaResp) {
        if (metric == null) {
            return Lists.newArrayList();
        }
        List<DrillDownDimension> drillDownDimensions = metric.getDrillDownDimensions();
        if (CollectionUtils.isEmpty(drillDownDimensions)) {
            return Lists.newArrayList();
        }
        return drillDownDimensions.stream()
                .filter(DrillDownDimension::isNecessary).map(DrillDownDimension::getDimensionId)
                .map(semanticSchemaResp::getDimension)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /** Concatenates the group-by and where fields, tolerating null or empty inputs. */
    private List<String> getDimensionFields(List<String> groupByFields, List<String> whereFields) {
        List<String> dimensionFields = Lists.newArrayList();
        if (!CollectionUtils.isEmpty(groupByFields)) {
            dimensionFields.addAll(groupByFields);
        }
        if (!CollectionUtils.isEmpty(whereFields)) {
            dimensionFields.addAll(whereFields);
        }
        return dimensionFields;
    }

    /** Returns the schema metrics whose bizName is one of the given metric fields. */
    private List<MetricResp> getMetrics(List<String> metricFields, SemanticSchemaResp semanticSchemaResp) {
        return semanticSchemaResp.getMetrics().stream()
                .filter(metricSchemaResp -> metricFields.contains(metricSchemaResp.getBizName()))
                .collect(Collectors.toList());
    }

    /** Returns the bizName of the dimension with the given id, or null when the schema has none. */
    private String convertDimensionIdToBizName(Long id, SemanticSchemaResp semanticSchemaResp) {
        DimSchemaResp dimension = semanticSchemaResp.getDimension(id);
        if (dimension == null) {
            return null;
        }
        return dimension.getBizName();
    }
}

View File

@@ -0,0 +1,88 @@
package com.tencent.supersonic.headless.server.aspect;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.server.utils.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertThrows;
/**
 * Tests for {@link MetricDrillDownChecker#checkQuery}. The mocked schema has metric {@code pv}
 * with drill-down dimensions 1 and 2, and metric {@code uv} with dimension 2 flagged as necessary.
 */
@Slf4j
public class MetricDrillDownCheckerTest {

    @Test
    void test_groupBy_in_drillDownDimension() {
        MetricDrillDownChecker checker = new MetricDrillDownChecker();
        SemanticSchemaResp schema = mockModelSchemaResp();
        // Passes when no exception is thrown.
        checker.checkQuery(schema, "select user_name, sum(pv) from t_1 group by user_name");
    }

    @Test
    void test_groupBy_not_in_drillDownDimension() {
        MetricDrillDownChecker checker = new MetricDrillDownChecker();
        SemanticSchemaResp schema = mockModelSchemaResp();
        String query = "select page, sum(pv) from t_1 group by page";
        assertThrows(InvalidArgumentException.class, () -> checker.checkQuery(schema, query));
    }

    @Test
    void test_groupBy_not_in_necessary_dimension() {
        MetricDrillDownChecker checker = new MetricDrillDownChecker();
        SemanticSchemaResp schema = mockModelSchemaResp();
        String query = "select user_name, count(distinct uv) from t_1 group by user_name";
        assertThrows(InvalidArgumentException.class, () -> checker.checkQuery(schema, query));
    }

    @Test
    void test_groupBy_no_necessary_dimension_setting() {
        MetricDrillDownChecker checker = new MetricDrillDownChecker();
        SemanticSchemaResp schema = mockModelSchemaNoDimensionSetting();
        // Passes when no exception is thrown.
        checker.checkQuery(schema,
                "select user_name, page, count(distinct uv) from t_1 group by user_name,page");
    }

    /** Schema whose metrics carry drill-down settings. */
    private SemanticSchemaResp mockModelSchemaResp() {
        return schemaWithMetrics(mockMetrics());
    }

    /** Schema whose metrics carry no drill-down settings. */
    private SemanticSchemaResp mockModelSchemaNoDimensionSetting() {
        return schemaWithMetrics(Lists.newArrayList(mockMetricsNoDrillDownSetting()));
    }

    /** Builds a schema response from the given metrics and the shared mock dimensions. */
    private SemanticSchemaResp schemaWithMetrics(List<MetricSchemaResp> metrics) {
        SemanticSchemaResp schema = new SemanticSchemaResp();
        schema.setMetrics(metrics);
        schema.setDimensions(mockDimensions());
        return schema;
    }

    private List<DimSchemaResp> mockDimensions() {
        DimSchemaResp userName = DataUtils.mockDimension(1L, "user_name", "用户名");
        DimSchemaResp department = DataUtils.mockDimension(2L, "department", "部门");
        DimSchemaResp page = DataUtils.mockDimension(3L, "page", "页面");
        return Lists.newArrayList(userName, department, page);
    }

    private List<MetricSchemaResp> mockMetrics() {
        MetricSchemaResp pv = DataUtils.mockMetric(1L, "pv", "访问次数",
                Lists.newArrayList(new DrillDownDimension(1L), new DrillDownDimension(2L)));
        MetricSchemaResp uv = DataUtils.mockMetric(2L, "uv", "访问用户数",
                Lists.newArrayList(new DrillDownDimension(2L, true)));
        return Lists.newArrayList(pv, uv);
    }

    private List<MetricSchemaResp> mockMetricsNoDrillDownSetting() {
        MetricSchemaResp pv = DataUtils.mockMetric(1L, "pv", Lists.newArrayList());
        MetricSchemaResp uv = DataUtils.mockMetric(2L, "uv", Lists.newArrayList());
        return Lists.newArrayList(pv, uv);
    }
}

View File

@@ -0,0 +1,45 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import java.util.List;
/**
 * Static factory helpers that assemble dimension and metric schema responses for tests.
 */
public class DataUtils {

    /**
     * Creates a dimension schema response.
     *
     * @param id      value stored as the dimension id
     * @param bizName value stored as the dimension bizName
     * @param name    value stored as the dimension name
     * @return a new {@code DimSchemaResp} with the three fields set
     */
    public static DimSchemaResp mockDimension(Long id, String bizName, String name) {
        DimSchemaResp dimension = new DimSchemaResp();
        dimension.setName(name);
        dimension.setBizName(bizName);
        dimension.setId(id);
        return dimension;
    }

    /**
     * Creates a metric whose relate-dimension is a {@code RelateDimension} built with its
     * no-arg constructor.
     *
     * @param id      value stored as the metric id
     * @param bizName value stored as the metric bizName
     * @return a new {@code MetricSchemaResp}; its name is never set here
     */
    public static MetricSchemaResp mockMetric(Long id, String bizName) {
        MetricSchemaResp metric = new MetricSchemaResp();
        metric.setRelateDimension(new RelateDimension());
        metric.setBizName(bizName);
        metric.setId(id);
        return metric;
    }

    /**
     * Creates a metric with the given drill-down dimensions.
     *
     * @param id                  value stored as the metric id
     * @param bizName             value stored as the metric bizName
     * @param name                value stored as the metric name (may be {@code null})
     * @param drillDownDimensions list passed to the {@code RelateDimension} builder
     * @return a new {@code MetricSchemaResp} carrying the built relate-dimension
     */
    public static MetricSchemaResp mockMetric(Long id, String bizName, String name,
            List<DrillDownDimension> drillDownDimensions) {
        RelateDimension relateDimension = RelateDimension.builder()
                .drillDownDimensions(drillDownDimensions)
                .build();
        MetricSchemaResp metric = new MetricSchemaResp();
        metric.setRelateDimension(relateDimension);
        metric.setBizName(bizName);
        metric.setName(name);
        metric.setId(id);
        return metric;
    }

    /**
     * Same as the four-argument overload, with the metric name passed as {@code null}.
     *
     * @param id                  value stored as the metric id
     * @param bizName             value stored as the metric bizName
     * @param drillDownDimensions list passed to the {@code RelateDimension} builder
     * @return a new {@code MetricSchemaResp} without a name
     */
    public static MetricSchemaResp mockMetric(Long id, String bizName,
            List<DrillDownDimension> drillDownDimensions) {
        final String noName = null;
        return mockMetric(id, bizName, noName, drillDownDimensions);
    }
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.Dim;
@@ -231,8 +232,10 @@ public class BenchMarkDemoDataLoader {
tagTimeDefaultConfig.setTimeMode(TimeMode.LAST);
tagTimeDefaultConfig.setUnit(7);
tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig);
tagTypeDefaultConfig.setDimensionIds(Lists.newArrayList());
tagTypeDefaultConfig.setMetricIds(Lists.newArrayList());
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
defaultDisplayInfo.setDimensionIds(Lists.newArrayList());
defaultDisplayInfo.setMetricIds(Lists.newArrayList());
tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.RECENT);

View File

@@ -23,8 +23,6 @@ import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
@@ -33,6 +31,9 @@ import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import java.util.Arrays;
import java.util.List;
@Component
@Slf4j
@Order(3)
@@ -164,7 +165,7 @@ public class ChatDemoLoader implements CommandLineRunner {
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setViewIds(Lists.newArrayList(-1L));
ruleQueryTool.setViewIds(Lists.newArrayList(1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {
@@ -190,7 +191,7 @@ public class ChatDemoLoader implements CommandLineRunner {
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setViewIds(Lists.newArrayList(-1L));
ruleQueryTool.setViewIds(Lists.newArrayList(2L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool);

View File

@@ -14,6 +14,7 @@ import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
@@ -146,7 +147,7 @@ public class ModelDemoDataLoader {
public void addModel_1() throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("超音数用户部门");
modelReq.setName("用户部门");
modelReq.setBizName("user_department");
modelReq.setDescription("用户部门信息");
modelReq.setDatabaseId(1L);
@@ -178,9 +179,9 @@ public class ModelDemoDataLoader {
public void addModel_2() throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("超音数PVUV统计");
modelReq.setName("PVUV统计");
modelReq.setBizName("s2_pv_uv_statis");
modelReq.setDescription("超音数PVUV统计");
modelReq.setDescription("PVUV统计");
modelReq.setDatabaseId(1L);
modelReq.setViewers(Arrays.asList("admin", "tom", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1"));
@@ -384,6 +385,7 @@ public class ModelDemoDataLoader {
metricReq.setId(1L);
metricReq.setName("访问次数");
metricReq.setBizName("pv");
metricReq.setDescription("一段时间内用户的访问次数");
MetricDefineByMeasureParams metricTypeParams = new MetricDefineByMeasureParams();
metricTypeParams.setExpr("s2_pv_uv_statis_pv");
List<MeasureParam> measures = new ArrayList<>();
@@ -404,7 +406,7 @@ public class ModelDemoDataLoader {
metricReq.setBizName("uv");
metricReq.setSensitiveLevel(SensitiveLevelEnum.LOW.getCode());
metricReq.setDescription("访问的用户个数");
metricReq.setAlias("UV");
metricReq.setAlias("UV,访问人数");
MetricDefineByFieldParams metricTypeParams = new MetricDefineByFieldParams();
metricTypeParams.setExpr("count(distinct user_id)");
List<FieldParam> fieldParams = new ArrayList<>();
@@ -491,8 +493,10 @@ public class ModelDemoDataLoader {
tagTimeDefaultConfig.setTimeMode(TimeMode.LAST);
tagTimeDefaultConfig.setUnit(7);
tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig);
tagTypeDefaultConfig.setDimensionIds(Lists.newArrayList(4L, 5L, 6L, 7L));
tagTypeDefaultConfig.setMetricIds(Lists.newArrayList(5L));
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
defaultDisplayInfo.setDimensionIds(Lists.newArrayList(4L, 5L, 6L, 7L));
defaultDisplayInfo.setMetricIds(Lists.newArrayList(5L));
tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.RECENT);

View File

@@ -20,7 +20,6 @@ com.tencent.supersonic.chat.core.corrector.SemanticCorrector=\
com.tencent.supersonic.chat.core.corrector.HavingCorrector
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
com.tencent.supersonic.chat.server.processor.parse.MetricCheckProcessor, \
com.tencent.supersonic.chat.server.processor.parse.ParseInfoProcessor, \
com.tencent.supersonic.chat.server.processor.parse.QueryRankProcessor, \
com.tencent.supersonic.chat.server.processor.parse.EntityInfoProcessor, \

View File

@@ -19,7 +19,6 @@ com.tencent.supersonic.chat.core.corrector.SemanticCorrector=\
com.tencent.supersonic.chat.core.corrector.HavingCorrector
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
com.tencent.supersonic.chat.server.processor.parse.MetricCheckProcessor, \
com.tencent.supersonic.chat.server.processor.parse.ParseInfoProcessor, \
com.tencent.supersonic.chat.server.processor.parse.QueryRankProcessor, \
com.tencent.supersonic.chat.server.processor.parse.EntityInfoProcessor, \