From 82f56109acc59ef4616d123543f44b6d2fd7dd47 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Wed, 13 Mar 2024 11:25:50 +0800 Subject: [PATCH] (improvement)(Chat) add tag filter in parse (#813) --- .../headless/core/chat/mapper/BaseMapper.java | 8 +++--- .../parser/rule/ContextInheritParser.java | 2 +- .../core/chat/parser/rule/RuleSqlParser.java | 25 ++++++++++++++--- .../chat/query/rule/RuleSemanticQuery.java | 28 +++++++++++++++++-- .../service/impl/DataSetServiceImpl.java | 6 +++- 5 files changed, 57 insertions(+), 12 deletions(-) diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java index b5af9d979..f6721d8b5 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java @@ -41,11 +41,11 @@ public abstract class BaseMapper implements SchemaMapper { public abstract void doMap(QueryContext queryContext); - public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) { - Map> modelElementMatches = schemaMap.getDataSetElementMatches(); - List schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>()); + public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) { + Map> dataSetElementMatches = schemaMap.getDataSetElementMatches(); + List schemaElementMatches = dataSetElementMatches.putIfAbsent(dataSetId, new ArrayList<>()); if (schemaElementMatches == null) { - schemaElementMatches = modelElementMatches.get(modelId); + schemaElementMatches = dataSetElementMatches.get(dataSetId); } //remove duplication AtomicBoolean needAddNew = new AtomicBoolean(true); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/ContextInheritParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/ContextInheritParser.java index b002e8b4d..bd79b2613 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/ContextInheritParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/ContextInheritParser.java @@ -67,7 +67,7 @@ public class ContextInheritParser implements SemanticParser { } elementMatches.addAll(matchesToInherit); - List queries = RuleSemanticQuery.resolve(elementMatches, queryContext); + List queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext); for (RuleSemanticQuery query : queries) { query.fillParseInfo(queryContext, chatContext); if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java index ddee89d04..f3fa53f26 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/rule/RuleSqlParser.java @@ -1,14 +1,17 @@ package com.tencent.supersonic.headless.core.chat.parser.rule; +import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.core.chat.parser.SemanticParser; +import com.tencent.supersonic.headless.core.chat.query.QueryManager; +import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.core.pojo.ChatContext; import com.tencent.supersonic.headless.core.pojo.QueryContext; -import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery; -import lombok.extern.slf4j.Slf4j; import java.util.Arrays; import java.util.List; +import lombok.extern.slf4j.Slf4j; /** * RuleSqlParser resolves a specific SemanticQuery according to co-appearance @@ -29,13 +32,27 @@ public class RuleSqlParser implements SemanticParser { // iterate all schemaElementMatches to resolve query mode for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) { List elementMatches = mapInfo.getMatchedElements(dataSetId); - List queries = RuleSemanticQuery.resolve(elementMatches, queryContext); + List queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext); for (RuleSemanticQuery query : queries) { query.fillParseInfo(queryContext, chatContext); - queryContext.getCandidateQueries().add(query); + SemanticParseInfo parseInfo = query.getParseInfo(); + QueryType queryType = queryContext.getQueryType(parseInfo.getDataSetId()); + if (isRightQuery(parseInfo, queryType)) { + queryContext.getCandidateQueries().add(query); + } } } auxiliaryParsers.stream().forEach(p -> p.parse(queryContext, chatContext)); } + + private boolean isRightQuery(SemanticParseInfo parseInfo, QueryType queryType) { + if (QueryType.TAG.equals(queryType) && QueryManager.isTagQuery(parseInfo.getQueryMode())) { + return true; + } + if (QueryType.METRIC.equals(queryType) && QueryManager.isMetricQuery(parseInfo.getQueryMode())) { + return true; + } + return false; + } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/rule/RuleSemanticQuery.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/rule/RuleSemanticQuery.java index fcc64889d..cd841ea9a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/rule/RuleSemanticQuery.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/rule/RuleSemanticQuery.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.query.rule; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; @@ -96,7 +97,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) { Set dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement) .map(SchemaElement::getDataSet).collect(Collectors.toSet()); - parseInfo.setDataSet(semanticSchema.getDataSet(dataSetIds.iterator().next())); + Long dataSetId = dataSetIds.iterator().next(); + parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId)); Map> dim2Values = new HashMap<>(); Map> id2Values = new HashMap<>(); Map> tag2Values = new HashMap<>(); @@ -114,6 +116,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { case VALUE: addToValues(semanticSchema, SchemaElementType.DIMENSION, dim2Values, schemaMatch); break; + case TAG: case DIMENSION: parseInfo.getDimensions().add(element); break; @@ -214,9 +217,10 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { this.parseInfo = parseInfo; } - public static List resolve(List candidateElementMatches, + public static List resolve(Long dataSetId, List candidateElementMatches, QueryContext queryContext) { List matchedQueries = new ArrayList<>(); + candidateElementMatches = filterByQueryType(dataSetId, candidateElementMatches, queryContext); for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) { List matches = semanticQuery.match(candidateElementMatches, queryContext); @@ -231,6 +235,26 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { return matchedQueries; } + private static List filterByQueryType(Long dataSetId, + List candidateElementMatches, QueryContext queryContext) { + QueryType queryType = queryContext.getQueryType(dataSetId); + if (QueryType.TAG.equals(queryType)) { + candidateElementMatches = candidateElementMatches.stream() + .filter(elementMatch -> !(SchemaElementType.METRIC.equals(elementMatch.getElement().getType()) + || SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()) + || SchemaElementType.VALUE.equals(elementMatch.getElement().getType())) + ) + .collect(Collectors.toList()); + } + if (QueryType.METRIC.equals(queryType)) { + candidateElementMatches = candidateElementMatches.stream() + .filter(elementMatch -> !(SchemaElementType.TAG.equals(elementMatch.getElement().getType()) + || SchemaElementType.TAG_VALUE.equals(elementMatch.getElement().getType()))) + .collect(Collectors.toList()); + } + return candidateElementMatches; + } + protected QueryStructReq convertQueryStruct() { return QueryReqBuilder.buildStructReq(parseInfo); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index 0c03a0e19..0c950aafc 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -171,7 +171,11 @@ public class DataSetServiceImpl dataSetResp.setAdminOrgs(StringUtils.isBlank(dataSetDO.getAdminOrg()) ? Lists.newArrayList() : Arrays.asList(dataSetDO.getAdminOrg().split(","))); dataSetResp.setTypeEnum(TypeEnums.DATASET); - dataSetResp.setQueryType(QueryType.valueOf(dataSetDO.getQueryType())); + String queryType = dataSetDO.getQueryType(); + if (Objects.isNull(queryType)) { + queryType = QueryType.METRIC.name(); + } + dataSetResp.setQueryType(QueryType.valueOf(queryType)); return dataSetResp; }