(improvement)(Chat) add tag filter in parse (#813)

This commit is contained in:
lexluo09
2024-03-13 11:25:50 +08:00
committed by GitHub
parent dfd25f7983
commit 82f56109ac
5 changed files with 57 additions and 12 deletions

View File

@@ -41,11 +41,11 @@ public abstract class BaseMapper implements SchemaMapper {
public abstract void doMap(QueryContext queryContext);
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getDataSetElementMatches();
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
List<SchemaElementMatch> schemaElementMatches = dataSetElementMatches.putIfAbsent(dataSetId, new ArrayList<>());
if (schemaElementMatches == null) {
schemaElementMatches = modelElementMatches.get(modelId);
schemaElementMatches = dataSetElementMatches.get(dataSetId);
}
//remove duplication
AtomicBoolean needAddNew = new AtomicBoolean(true);

View File

@@ -67,7 +67,7 @@ public class ContextInheritParser implements SemanticParser {
}
elementMatches.addAll(matchesToInherit);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(queryContext, chatContext);
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) {

View File

@@ -1,14 +1,17 @@
package com.tencent.supersonic.headless.core.chat.parser.rule;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
/**
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance
@@ -29,13 +32,27 @@ public class RuleSqlParser implements SemanticParser {
// iterate all schemaElementMatches to resolve query mode
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(queryContext, chatContext);
queryContext.getCandidateQueries().add(query);
SemanticParseInfo parseInfo = query.getParseInfo();
QueryType queryType = queryContext.getQueryType(parseInfo.getDataSetId());
if (isRightQuery(parseInfo, queryType)) {
queryContext.getCandidateQueries().add(query);
}
}
}
auxiliaryParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
}
private boolean isRightQuery(SemanticParseInfo parseInfo, QueryType queryType) {
if (QueryType.TAG.equals(queryType) && QueryManager.isTagQuery(parseInfo.getQueryMode())) {
return true;
}
if (QueryType.METRIC.equals(queryType) && QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
return true;
}
return false;
}
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.query.rule;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -96,7 +97,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSet).collect(Collectors.toSet());
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetIds.iterator().next()));
Long dataSetId = dataSetIds.iterator().next();
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> tag2Values = new HashMap<>();
@@ -114,6 +116,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
case VALUE:
addToValues(semanticSchema, SchemaElementType.DIMENSION, dim2Values, schemaMatch);
break;
case TAG:
case DIMENSION:
parseInfo.getDimensions().add(element);
break;
@@ -214,9 +217,10 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
this.parseInfo = parseInfo;
}
public static List<RuleSemanticQuery> resolve(List<SchemaElementMatch> candidateElementMatches,
public static List<RuleSemanticQuery> resolve(Long dataSetId, List<SchemaElementMatch> candidateElementMatches,
QueryContext queryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
candidateElementMatches = filterByQueryType(dataSetId, candidateElementMatches, queryContext);
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, queryContext);
@@ -231,6 +235,26 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
return matchedQueries;
}
private static List<SchemaElementMatch> filterByQueryType(Long dataSetId,
List<SchemaElementMatch> candidateElementMatches, QueryContext queryContext) {
QueryType queryType = queryContext.getQueryType(dataSetId);
if (QueryType.TAG.equals(queryType)) {
candidateElementMatches = candidateElementMatches.stream()
.filter(elementMatch -> !(SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType())
|| SchemaElementType.VALUE.equals(elementMatch.getElement().getType()))
)
.collect(Collectors.toList());
}
if (QueryType.METRIC.equals(queryType)) {
candidateElementMatches = candidateElementMatches.stream()
.filter(elementMatch -> !(SchemaElementType.TAG.equals(elementMatch.getElement().getType())
|| SchemaElementType.TAG_VALUE.equals(elementMatch.getElement().getType())))
.collect(Collectors.toList());
}
return candidateElementMatches;
}
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}

View File

@@ -171,7 +171,11 @@ public class DataSetServiceImpl
dataSetResp.setAdminOrgs(StringUtils.isBlank(dataSetDO.getAdminOrg())
? Lists.newArrayList() : Arrays.asList(dataSetDO.getAdminOrg().split(",")));
dataSetResp.setTypeEnum(TypeEnums.DATASET);
dataSetResp.setQueryType(QueryType.valueOf(dataSetDO.getQueryType()));
String queryType = dataSetDO.getQueryType();
if (Objects.isNull(queryType)) {
queryType = QueryType.METRIC.name();
}
dataSetResp.setQueryType(QueryType.valueOf(queryType));
return dataSetResp;
}