(improvement)(Headless) Support specifying metrics, tags, and dimensions query modes in the Map phase (#959)

This commit is contained in:
lexluo09
2024-04-28 17:02:32 +08:00
committed by GitHub
parent 83b80e35f0
commit a6724f886b
10 changed files with 91 additions and 23 deletions

View File

@@ -0,0 +1,8 @@
package com.tencent.supersonic.headless.api.pojo;
public enum QueryDataType {
METRIC,
DIMENSION,
TAG,
ALL
}

View File

@@ -31,6 +31,7 @@ public class SchemaElement implements Serializable {
private String defaultAgg; private String defaultAgg;
private String dataFormatType; private String dataFormatType;
private double order; private double order;
private int isTag;
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.apache.commons.collections4.CollectionUtils;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@@ -31,18 +30,4 @@ public class SchemaMapInfo {
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) { public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches); dataSetElementMatches.put(dataSet, elementMatches);
} }
public Map<Long, String> generateDataSetMap() {
Map<Long, String> dataSetIdToName = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
List<SchemaElementMatch> values = entry.getValue();
if (CollectionUtils.isNotEmpty(values)) {
SchemaElementMatch schemaElementMatch = values.get(0);
String dataSetName = schemaElementMatch.getElement().getDataSetName();
dataSetIdToName.put(entry.getKey(), dataSetName);
}
}
return dataSetIdToName;
}
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.request; package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
@@ -15,4 +16,5 @@ public class QueryMapReq {
private User user; private User user;
private Integer topN = 10; private Integer topN = 10;
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private QueryDataType queryDataType = QueryDataType.ALL;
} }

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data; import lombok.Data;
@@ -20,4 +21,5 @@ public class QueryReq {
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL;
} }

View File

@@ -16,6 +16,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -30,13 +31,51 @@ public abstract class BaseMapper implements SchemaMapper {
try { try {
doMap(queryContext); doMap(queryContext);
filter(queryContext);
} catch (Exception e) { } catch (Exception e) {
log.error("work error", e); log.error("work error", e);
} }
long cost = System.currentTimeMillis() - startTime; long cost = System.currentTimeMillis() - startTime;
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getDataSetElementMatches());
queryContext.getMapInfo().getDataSetElementMatches()); }
private void filter(QueryContext queryContext) {
switch (queryContext.getQueryDataType()) {
case TAG:
filterByQueryDataType(queryContext, element -> !(element.getIsTag() > 0));
break;
case METRIC:
filterByQueryDataType(queryContext, element -> !SchemaElementType.METRIC.equals(element.getType()));
break;
case DIMENSION:
filterByQueryDataType(queryContext, element -> {
boolean isDimensionOrValue = SchemaElementType.DIMENSION.equals(element.getType())
|| SchemaElementType.VALUE.equals(element.getType());
return !isDimensionOrValue;
});
break;
case ALL:
default:
break;
}
}
private static void filterByQueryDataType(QueryContext queryContext, Predicate<SchemaElement> needRemovePredicate) {
queryContext.getMapInfo().getDataSetElementMatches().values().stream().forEach(
schemaElementMatches -> schemaElementMatches.removeIf(
schemaElementMatch -> {
SchemaElement element = schemaElementMatch.getElement();
SchemaElementType type = element.getType();
if (SchemaElementType.ENTITY.equals(type) || SchemaElementType.DATASET.equals(type)
|| SchemaElementType.ID.equals(type)) {
return false;
}
return needRemovePredicate.test(element);
}
));
} }
public abstract void doMap(QueryContext queryContext); public abstract void doMap(QueryContext queryContext);
@@ -107,4 +146,4 @@ public abstract class BaseMapper implements SchemaMapper {
} }
return element.getAlias(); return element.getAlias();
} }
} }

View File

@@ -89,7 +89,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) .filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new)); .collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after isSimilarity parseResults:{}", hanlpMapResults); log.info("detectSegment:{},after isSimilarity parseResults:{}", detectSegment, hanlpMapResults);
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> { hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
parseResult.setOffset(offset); parseResult.setOffset(offset);

View File

@@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
@@ -45,6 +46,7 @@ public class QueryContext {
private SemanticSchema semanticSchema; private SemanticSchema semanticSchema;
@JsonIgnore @JsonIgnore
private WorkflowState workflowState; private WorkflowState workflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
public List<SemanticQuery> getCandidateQueries() { public List<SemanticQuery> getCandidateQueries() {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);

View File

@@ -12,6 +12,9 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@@ -21,10 +24,6 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
public class DataSetSchemaBuilder { public class DataSetSchemaBuilder {
public static DataSetSchema build(DataSetSchemaResp resp) { public static DataSetSchema build(DataSetSchemaResp resp) {
@@ -77,6 +76,7 @@ public class DataSetSchemaBuilder {
.type(SchemaElementType.TAG) .type(SchemaElementType.TAG)
.useCnt(metric.getUseCnt()) .useCnt(metric.getUseCnt())
.alias(alias) .alias(alias)
.isTag(metric.getIsTag())
.build(); .build();
tags.add(tagToAdd); tags.add(tagToAdd);
} }
@@ -109,6 +109,7 @@ public class DataSetSchemaBuilder {
.useCnt(dim.getUseCnt()) .useCnt(dim.getUseCnt())
.alias(alias) .alias(alias)
.schemaValueMaps(schemaValueMaps) .schemaValueMaps(schemaValueMaps)
.isTag(dim.getIsTag())
.build(); .build();
tags.add(tagToAdd); tags.add(tagToAdd);
} }
@@ -157,6 +158,7 @@ public class DataSetSchemaBuilder {
.useCnt(dim.getUseCnt()) .useCnt(dim.getUseCnt())
.alias(alias) .alias(alias)
.schemaValueMaps(schemaValueMaps) .schemaValueMaps(schemaValueMaps)
.isTag(dim.getIsTag())
.build(); .build();
dimensions.add(dimToAdd); dimensions.add(dimToAdd);
} }
@@ -188,6 +190,7 @@ public class DataSetSchemaBuilder {
.type(SchemaElementType.VALUE) .type(SchemaElementType.VALUE)
.useCnt(dim.getUseCnt()) .useCnt(dim.getUseCnt())
.alias(new ArrayList<>(Arrays.asList(dimValueAlias.toArray(new String[0])))) .alias(new ArrayList<>(Arrays.asList(dimValueAlias.toArray(new String[0]))))
.isTag(dim.getIsTag())
.build(); .build();
dimensionValues.add(dimValueToAdd); dimensionValues.add(dimValueToAdd);
} }
@@ -213,6 +216,7 @@ public class DataSetSchemaBuilder {
.relatedSchemaElements(getRelateSchemaElement(metric)) .relatedSchemaElements(getRelateSchemaElement(metric))
.defaultAgg(metric.getDefaultAgg()) .defaultAgg(metric.getDefaultAgg())
.dataFormatType(metric.getDataFormatType()) .dataFormatType(metric.getDataFormatType())
.isTag(metric.getIsTag())
.build(); .build();
metrics.add(metricToAdd); metrics.add(metricToAdd);

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless; package com.tencent.supersonic.headless;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.server.service.MetaDiscoveryService; import com.tencent.supersonic.headless.server.service.MetaDiscoveryService;
@@ -28,4 +29,28 @@ public class MetaDiscoveryTest extends BaseTest {
Assert.assertNotEquals(0, mapMeta.getMapFields()); Assert.assertNotEquals(0, mapMeta.getMapFields());
Assert.assertNotEquals(0, mapMeta.getTopFields()); Assert.assertNotEquals(0, mapMeta.getTopFields());
} }
@Test
public void testGetMapMeta2() throws Exception {
QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("风格为流行的艺人");
queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setDataSetNames(Collections.singletonList("艺人库"));
queryMapReq.setQueryDataType(QueryDataType.TAG);
MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq);
Assert.assertNotNull(mapMeta);
}
@Test
public void testGetMapMeta3() throws Exception {
QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("超音数访问次数最高的部门");
queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setDataSetNames(Collections.singletonList("超音数"));
queryMapReq.setQueryDataType(QueryDataType.METRIC);
MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq);
Assert.assertNotNull(mapMeta);
}
} }