mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(feature)add metric check parser in chat and add metric check convert in semantic, download metric data in semantic (#241)
* (improvement)(chat) add metric check parser * (improvement)(semantic) support metric data download --------- Co-authored-by: jolunoluo
This commit is contained in:
@@ -0,0 +1,100 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class MetricCheckParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
||||
if (CollectionUtils.isEmpty(semanticQueries)) {
|
||||
return;
|
||||
}
|
||||
semanticQueries.removeIf(this::removeQuery);
|
||||
}
|
||||
|
||||
private boolean removeQuery(SemanticQuery semanticQuery) {
|
||||
if (semanticQuery instanceof MetricSemanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
List<SchemaElementMatch> schemaElementMatches = parseInfo.getElementMatches();
|
||||
List<SchemaElementMatch> elementMatchFiltered =
|
||||
filterMetricElement(schemaElementMatches, parseInfo.getModelId());
|
||||
return 0 >= getMetricElementMatchCount(elementMatchFiltered);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private List<SchemaElementMatch> filterMetricElement(List<SchemaElementMatch> elementMatches, Long modelId) {
|
||||
List<SchemaElementMatch> filterSchemaElementMatch = Lists.newArrayList();
|
||||
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
ModelSchema modelSchema = semanticInterpreter.getModelSchema(modelId, true);
|
||||
Set<SchemaElement> metricElements = modelSchema.getMetrics();
|
||||
Map<Long, SchemaElementMatch> valueElementMatchMap = getValueElementMap(elementMatches);
|
||||
Map<Long, SchemaElement> metricMap = metricElements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, e -> e, (e1, e2) -> e2));
|
||||
for (SchemaElementMatch schemaElementMatch : elementMatches) {
|
||||
if (!SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType())) {
|
||||
filterSchemaElementMatch.add(schemaElementMatch);
|
||||
continue;
|
||||
}
|
||||
SchemaElement metric = metricMap.get(schemaElementMatch.getElement().getId());
|
||||
List<Long> necessaryDimensionIds = getNecessaryDimensionIds(metric);
|
||||
boolean flag = true;
|
||||
for (Long necessaryDimensionId : necessaryDimensionIds) {
|
||||
if (!valueElementMatchMap.containsKey(necessaryDimensionId)) {
|
||||
flag = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (flag) {
|
||||
filterSchemaElementMatch.add(schemaElementMatch);
|
||||
}
|
||||
}
|
||||
return filterSchemaElementMatch;
|
||||
}
|
||||
|
||||
private Map<Long, SchemaElementMatch> getValueElementMap(List<SchemaElementMatch> elementMatches) {
|
||||
return elementMatches.stream()
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.VALUE.equals(elementMatch.getElement().getType()))
|
||||
.collect(Collectors.toMap(elementMatch -> elementMatch.getElement().getId(), e -> e, (e1, e2) -> e1));
|
||||
}
|
||||
|
||||
private long getMetricElementMatchCount(List<SchemaElementMatch> elementMatches) {
|
||||
return elementMatches.stream().filter(elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType()))
|
||||
.count();
|
||||
}
|
||||
|
||||
private List<Long> getNecessaryDimensionIds(SchemaElement metric) {
|
||||
if (metric == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<RelateSchemaElement> relateSchemaElements = metric.getRelateSchemaElements();
|
||||
if (CollectionUtils.isEmpty(relateSchemaElements)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return relateSchemaElements.stream()
|
||||
.filter(RelateSchemaElement::isNecessary).map(RelateSchemaElement::getDimensionId)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,14 +3,9 @@ package com.tencent.supersonic.chat.query.rule.metric;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
@@ -28,11 +23,7 @@ import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaR
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@@ -86,53 +77,9 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
filteredMatches.add(schemaElementMatch);
|
||||
}
|
||||
}
|
||||
filteredMatches = metricRelateDimensionCheck(filteredMatches, modelId);
|
||||
return filteredMatches;
|
||||
}
|
||||
|
||||
private List<SchemaElementMatch> metricRelateDimensionCheck(List<SchemaElementMatch> elementMatches, Long modelId) {
|
||||
List<SchemaElementMatch> filterSchemaElementMatch = Lists.newArrayList();
|
||||
|
||||
ModelSchema modelSchema = semanticInterpreter.getModelSchema(modelId, true);
|
||||
Set<SchemaElement> metricElements = modelSchema.getMetrics();
|
||||
Map<Long, SchemaElementMatch> valueElementMatchMap = elementMatches.stream()
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.VALUE.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(elementMatch.getElement().getType()))
|
||||
.collect(Collectors.toMap(elementMatch -> elementMatch.getElement().getId(), e -> e, (e1, e2) -> e1));
|
||||
Map<Long, SchemaElement> metricMap = metricElements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, e -> e, (e1, e2) -> e2));
|
||||
|
||||
for (SchemaElementMatch schemaElementMatch : elementMatches) {
|
||||
if (!SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType())) {
|
||||
filterSchemaElementMatch.add(schemaElementMatch);
|
||||
continue;
|
||||
}
|
||||
SchemaElement metric = metricMap.get(schemaElementMatch.getElement().getId());
|
||||
if (metric == null) {
|
||||
continue;
|
||||
}
|
||||
List<RelateSchemaElement> relateSchemaElements = metric.getRelateSchemaElements();
|
||||
if (CollectionUtils.isEmpty(relateSchemaElements)) {
|
||||
filterSchemaElementMatch.add(schemaElementMatch);
|
||||
continue;
|
||||
}
|
||||
List<Long> necessaryDimensionIds = relateSchemaElements.stream()
|
||||
.filter(RelateSchemaElement::isNecessary).map(RelateSchemaElement::getDimensionId)
|
||||
.collect(Collectors.toList());
|
||||
boolean flag = true;
|
||||
for (Long necessaryDimensionId : necessaryDimensionIds) {
|
||||
if (!valueElementMatchMap.containsKey(necessaryDimensionId)) {
|
||||
flag = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (flag) {
|
||||
filterSchemaElementMatch.add(schemaElementMatch);
|
||||
}
|
||||
}
|
||||
return filterSchemaElementMatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
|
||||
|
||||
@@ -32,7 +32,7 @@ public class RecommendController {
|
||||
|
||||
@GetMapping("recommend/metric/{modelId}")
|
||||
public RecommendResp recommendMetricMode(@PathVariable("modelId") Long modelId,
|
||||
@RequestParam(value = "metric", required = false) Long metricId,
|
||||
@RequestParam(value = "metricId", required = false) Long metricId,
|
||||
@RequestParam(value = "limit", required = false) Long limit) {
|
||||
RecommendReq recommendReq = new RecommendReq();
|
||||
recommendReq.setModelId(modelId);
|
||||
|
||||
Reference in New Issue
Block a user