From a6724f886b7a7b5d88ec9e9df27777e2a5b798bb Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sun, 28 Apr 2024 17:02:32 +0800 Subject: [PATCH] (improvement)(Headless) Support specifying metrics, tags, and dimensions query modes in the Map phase (#959) --- .../headless/api/pojo/QueryDataType.java | 8 ++++ .../headless/api/pojo/SchemaElement.java | 1 + .../headless/api/pojo/SchemaMapInfo.java | 15 ------- .../api/pojo/request/QueryMapReq.java | 2 + .../headless/api/pojo/request/QueryReq.java | 2 + .../headless/core/chat/mapper/BaseMapper.java | 45 +++++++++++++++++-- .../chat/mapper/HanlpDictMatchStrategy.java | 2 +- .../headless/core/pojo/QueryContext.java | 2 + .../server/utils/DataSetSchemaBuilder.java | 12 +++-- .../headless/MetaDiscoveryTest.java | 25 +++++++++++ 10 files changed, 91 insertions(+), 23 deletions(-) create mode 100644 headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java new file mode 100644 index 000000000..08dec1bdf --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java @@ -0,0 +1,8 @@ +package com.tencent.supersonic.headless.api.pojo; + +public enum QueryDataType { + METRIC, + DIMENSION, + TAG, + ALL +} \ No newline at end of file diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index d680a3e22..398169d78 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -31,6 +31,7 @@ public class SchemaElement implements Serializable { private String defaultAgg; private String dataFormatType; private double order; + private int isTag; @Override public boolean equals(Object o) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java index ee7c9d4e0..483b2a428 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.api.pojo; import com.google.common.collect.Lists; -import org.apache.commons.collections4.CollectionUtils; import java.util.HashMap; import java.util.List; @@ -31,18 +30,4 @@ public class SchemaMapInfo { public void setMatchedElements(Long dataSet, List elementMatches) { dataSetElementMatches.put(dataSet, elementMatches); } - - public Map generateDataSetMap() { - Map dataSetIdToName = new HashMap<>(); - for (Map.Entry> entry : dataSetElementMatches.entrySet()) { - List values = entry.getValue(); - if (CollectionUtils.isNotEmpty(values)) { - SchemaElementMatch schemaElementMatch = values.get(0); - String dataSetName = schemaElementMatch.getElement().getDataSetName(); - dataSetIdToName.put(entry.getKey(), dataSetName); - } - } - return dataSetIdToName; - } - } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java index 6b53b06bf..4eac5510a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import lombok.Data; import lombok.ToString; @@ -15,4 +16,5 @@ public class QueryMapReq { private User user; private Integer topN = 10; private MapModeEnum mapModeEnum = MapModeEnum.STRICT; + private QueryDataType queryDataType = QueryDataType.ALL; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java index 6ee1eb61b..55f360b42 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; +import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import lombok.Data; @@ -20,4 +21,5 @@ public class QueryReq { private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private SchemaMapInfo mapInfo = new SchemaMapInfo(); + private QueryDataType queryDataType = QueryDataType.ALL; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java index ac28b08df..99c53d3b4 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java @@ -16,6 +16,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Predicate; import java.util.stream.Collectors; @Slf4j @@ -30,13 +31,51 @@ public abstract class BaseMapper implements SchemaMapper { try { doMap(queryContext); + filter(queryContext); } catch (Exception e) { log.error("work error", e); } long cost = System.currentTimeMillis() - startTime; - log.info("after {},cost:{},mapInfo:{}", simpleName, cost, - queryContext.getMapInfo().getDataSetElementMatches()); + log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getDataSetElementMatches()); + } + + private void filter(QueryContext queryContext) { + + switch (queryContext.getQueryDataType()) { + case TAG: + filterByQueryDataType(queryContext, element -> !(element.getIsTag() > 0)); + break; + case METRIC: + filterByQueryDataType(queryContext, element -> !SchemaElementType.METRIC.equals(element.getType())); + break; + case DIMENSION: + filterByQueryDataType(queryContext, element -> { + boolean isDimensionOrValue = SchemaElementType.DIMENSION.equals(element.getType()) + || SchemaElementType.VALUE.equals(element.getType()); + return !isDimensionOrValue; + }); + break; + case ALL: + default: + break; + } + + } + + private static void filterByQueryDataType(QueryContext queryContext, Predicate needRemovePredicate) { + queryContext.getMapInfo().getDataSetElementMatches().values().stream().forEach( + schemaElementMatches -> schemaElementMatches.removeIf( + schemaElementMatch -> { + SchemaElement element = schemaElementMatch.getElement(); + SchemaElementType type = element.getType(); + if (SchemaElementType.ENTITY.equals(type) || SchemaElementType.DATASET.equals(type) + || SchemaElementType.ID.equals(type)) { + return false; + } + return needRemovePredicate.test(element); + } + )); } public abstract void doMap(QueryContext queryContext); @@ -107,4 +146,4 @@ public abstract class BaseMapper implements SchemaMapper { } return element.getAlias(); } -} +} \ No newline at end of file diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java index 44184ee78..cf4a44211 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java @@ -89,7 +89,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) .collect(Collectors.toCollection(LinkedHashSet::new)); - log.info("after isSimilarity parseResults:{}", hanlpMapResults); + log.info("detectSegment:{},after isSimilarity parseResults:{}", detectSegment, hanlpMapResults); hanlpMapResults = hanlpMapResults.stream().map(parseResult -> { parseResult.setOffset(offset); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java index 0ace7ecf1..dfb0ce50e 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; @@ -45,6 +46,7 @@ public class QueryContext { private SemanticSchema semanticSchema; @JsonIgnore private WorkflowState workflowState; + private QueryDataType queryDataType = QueryDataType.ALL; public List getCandidateQueries() { OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java index 5949a5496..0e27d6581 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java @@ -12,6 +12,9 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap; import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; +import org.apache.logging.log4j.util.Strings; +import org.springframework.beans.BeanUtils; +import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.Arrays; @@ -21,10 +24,6 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; -import org.apache.logging.log4j.util.Strings; -import org.springframework.beans.BeanUtils; -import org.springframework.util.CollectionUtils; - public class DataSetSchemaBuilder { public static DataSetSchema build(DataSetSchemaResp resp) { @@ -77,6 +76,7 @@ public class DataSetSchemaBuilder { .type(SchemaElementType.TAG) .useCnt(metric.getUseCnt()) .alias(alias) + .isTag(metric.getIsTag()) .build(); tags.add(tagToAdd); } @@ -109,6 +109,7 @@ public class DataSetSchemaBuilder { .useCnt(dim.getUseCnt()) .alias(alias) .schemaValueMaps(schemaValueMaps) + .isTag(dim.getIsTag()) .build(); tags.add(tagToAdd); } @@ -157,6 +158,7 @@ public class DataSetSchemaBuilder { .useCnt(dim.getUseCnt()) .alias(alias) .schemaValueMaps(schemaValueMaps) + .isTag(dim.getIsTag()) .build(); dimensions.add(dimToAdd); } @@ -188,6 +190,7 @@ public class DataSetSchemaBuilder { .type(SchemaElementType.VALUE) .useCnt(dim.getUseCnt()) .alias(new ArrayList<>(Arrays.asList(dimValueAlias.toArray(new String[0])))) + .isTag(dim.getIsTag()) .build(); dimensionValues.add(dimValueToAdd); } @@ -213,6 +216,7 @@ public class DataSetSchemaBuilder { .relatedSchemaElements(getRelateSchemaElement(metric)) .defaultAgg(metric.getDefaultAgg()) .dataFormatType(metric.getDataFormatType()) + .isTag(metric.getIsTag()) .build(); metrics.add(metricToAdd); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java index faff00179..ce12b5056 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.server.service.MetaDiscoveryService; @@ -28,4 +29,28 @@ public class MetaDiscoveryTest extends BaseTest { Assert.assertNotEquals(0, mapMeta.getMapFields()); Assert.assertNotEquals(0, mapMeta.getTopFields()); } + + @Test + public void testGetMapMeta2() throws Exception { + QueryMapReq queryMapReq = new QueryMapReq(); + queryMapReq.setQueryText("风格为流行的艺人"); + queryMapReq.setTopN(10); + queryMapReq.setUser(User.getFakeUser()); + queryMapReq.setDataSetNames(Collections.singletonList("艺人库")); + queryMapReq.setQueryDataType(QueryDataType.TAG); + MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq); + Assert.assertNotNull(mapMeta); + } + + @Test + public void testGetMapMeta3() throws Exception { + QueryMapReq queryMapReq = new QueryMapReq(); + queryMapReq.setQueryText("超音数访问次数最高的部门"); + queryMapReq.setTopN(10); + queryMapReq.setUser(User.getFakeUser()); + queryMapReq.setDataSetNames(Collections.singletonList("超音数")); + queryMapReq.setQueryDataType(QueryDataType.METRIC); + MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq); + Assert.assertNotNull(mapMeta); + } }