(improvement)(semantic) support metric relate dimension setting (#229)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-10-16 16:43:09 +08:00
committed by GitHub
parent 86bf40c8fb
commit f2e8207245
38 changed files with 508 additions and 87 deletions

View File

@@ -47,9 +47,9 @@ public interface SemanticInterpreter {
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
List<DomainResp> getDomainList(User user);

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.api.pojo;
import lombok.Data;
@Data
public class RelateSchemaElement {
private Long dimensionId;
private boolean isNecessary;
}

View File

@@ -22,10 +22,9 @@ public class SchemaElement implements Serializable {
private String bizName;
private Long useCnt;
private SchemaElementType type;
private List<String> alias;
private List<SchemaValueMap> schemaValueMaps;
private List<RelateSchemaElement> relateSchemaElements;
private String defaultAgg;

View File

@@ -16,6 +16,6 @@ public class QueryDataReq {
private Set<QueryFilter> dimensionFilters = new HashSet<>();
private Set<QueryFilter> metricFilters = new HashSet<>();
private DateConf dateInfo;
private Long queryId = 7L;
private Integer parseId = 2;
private Long queryId;
private Integer parseId;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
@Data
public class RecommendReq {
private Long modelId;
private Long metricId;
}

View File

@@ -4,9 +4,13 @@ import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -24,7 +28,11 @@ import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaR
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -66,7 +74,6 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
for (SchemaElementMatch schemaElementMatch : candidateElementMatches) {
SchemaElementType type = schemaElementMatch.getElement().getType();
if (SchemaElementType.DIMENSION.equals(type) || SchemaElementType.VALUE.equals(type)) {
if (!blackDimIdList.contains(schemaElementMatch.getElement().getId())) {
filteredMatches.add(schemaElementMatch);
@@ -79,9 +86,54 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
filteredMatches.add(schemaElementMatch);
}
}
filteredMatches = metricRelateDimensionCheck(filteredMatches, modelId);
return filteredMatches;
}
private List<SchemaElementMatch> metricRelateDimensionCheck(List<SchemaElementMatch> elementMatches, Long modelId) {
List<SchemaElementMatch> filterSchemaElementMatch = Lists.newArrayList();
ModelSchema modelSchema = semanticInterpreter.getModelSchema(modelId, true);
Set<SchemaElement> metricElements = modelSchema.getMetrics();
Map<Long, SchemaElementMatch> valueElementMatchMap = elementMatches.stream()
.filter(elementMatch ->
SchemaElementType.VALUE.equals(elementMatch.getElement().getType())
|| SchemaElementType.ID.equals(elementMatch.getElement().getType()))
.collect(Collectors.toMap(elementMatch -> elementMatch.getElement().getId(), e -> e, (e1, e2) -> e1));
Map<Long, SchemaElement> metricMap = metricElements.stream()
.collect(Collectors.toMap(SchemaElement::getId, e -> e, (e1, e2) -> e2));
for (SchemaElementMatch schemaElementMatch : elementMatches) {
if (!SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType())) {
filterSchemaElementMatch.add(schemaElementMatch);
continue;
}
SchemaElement metric = metricMap.get(schemaElementMatch.getElement().getId());
if (metric == null) {
continue;
}
List<RelateSchemaElement> relateSchemaElements = metric.getRelateSchemaElements();
if (CollectionUtils.isEmpty(relateSchemaElements)) {
filterSchemaElementMatch.add(schemaElementMatch);
continue;
}
List<Long> necessaryDimensionIds = relateSchemaElements.stream()
.filter(RelateSchemaElement::isNecessary).map(RelateSchemaElement::getDimensionId)
.collect(Collectors.toList());
boolean flag = true;
for (Long necessaryDimensionId : necessaryDimensionIds) {
if (!valueElementMatchMap.containsKey(necessaryDimensionId)) {
flag = false;
break;
}
}
if (flag) {
filterSchemaElementMatch.add(schemaElementMatch);
}
}
return filterSchemaElementMatch;
}
@Override
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);

View File

@@ -1,21 +1,15 @@
package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import com.tencent.supersonic.chat.service.RecommendService;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestParam;
import java.util.List;
/**
@@ -30,31 +24,25 @@ public class RecommendController {
@GetMapping("recommend/{modelId}")
public RecommendResp recommend(@PathVariable("modelId") Long modelId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryReq queryCtx = new QueryReq();
queryCtx.setUser(UserHolder.findUser(request, response));
queryCtx.setModelId(modelId);
return recommendService.recommend(queryCtx, limit);
@RequestParam(value = "limit", required = false) Long limit) {
RecommendReq recommendReq = new RecommendReq();
recommendReq.setModelId(modelId);
return recommendService.recommend(recommendReq, limit);
}
@GetMapping("recommend/metric/{modelId}")
public RecommendResp recommendMetricMode(@PathVariable("modelId") Long modelId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryReq queryCtx = new QueryReq();
queryCtx.setUser(UserHolder.findUser(request, response));
queryCtx.setModelId(modelId);
return recommendService.recommendMetricMode(queryCtx, limit);
@RequestParam(value = "metric", required = false) Long metricId,
@RequestParam(value = "limit", required = false) Long limit) {
RecommendReq recommendReq = new RecommendReq();
recommendReq.setModelId(modelId);
recommendReq.setMetricId(metricId);
return recommendService.recommendMetricMode(recommendReq, limit);
}
@GetMapping("recommend/question")
public List<RecommendQuestionResp> recommendQuestion(
@RequestParam(value = "modelId", required = false) Long modelId,
HttpServletRequest request,
HttpServletResponse response) {
@RequestParam(value = "modelId", required = false) Long modelId) {
return recommendService.recommendQuestion(modelId);
}
}

View File

@@ -1,10 +1,8 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import java.util.List;
/***
@@ -12,9 +10,9 @@ import java.util.List;
*/
public interface RecommendService {
RecommendResp recommend(QueryReq queryCtx, Long limit);
RecommendResp recommend(RecommendReq recommendReq, Long limit);
RecommendResp recommendMetricMode(QueryReq queryCtx, Long limit);
RecommendResp recommendMetricMode(RecommendReq recommendReq, Long limit);
List<RecommendQuestionResp> recommendQuestion(Long modelId);
}

View File

@@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
@@ -14,12 +15,15 @@ import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.RecommendService;
import com.tencent.supersonic.chat.service.SemanticService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -37,20 +41,40 @@ public class RecommendServiceImpl implements RecommendService {
private SemanticService semanticService;
@Override
public RecommendResp recommend(QueryReq queryCtx, Long limit) {
public RecommendResp recommend(RecommendReq recommendReq, Long limit) {
if (Objects.isNull(limit) || limit <= 0) {
limit = Long.MAX_VALUE;
}
log.debug("limit:{}", limit);
Long modelId = queryCtx.getModelId();
Long modelId = recommendReq.getModelId();
if (Objects.isNull(modelId)) {
return new RecommendResp();
}
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
List<Long> drillDownDimensions = Lists.newArrayList();
Set<SchemaElement> metricElements = modelSchema.getMetrics();
if (recommendReq.getMetricId() != null && !CollectionUtils.isEmpty(metricElements)) {
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
recommendReq.getMetricId().equals(schemaElement.getId())
&& !CollectionUtils.isEmpty(schemaElement.getRelateSchemaElements()))
.findFirst();
if (metric.isPresent()) {
drillDownDimensions = metric.get().getRelateSchemaElements().stream()
.map(RelateSchemaElement::getDimensionId).collect(Collectors.toList());
}
}
final List<Long> drillDownDimensionsFinal = drillDownDimensions;
List<SchemaElement> dimensions = modelSchema.getDimensions().stream()
.filter(dim -> Objects.nonNull(dim) && Objects.nonNull(dim.getUseCnt()))
.filter(dim -> {
if (Objects.isNull(dim)) {
return false;
}
if (!CollectionUtils.isEmpty(drillDownDimensionsFinal)) {
return drillDownDimensionsFinal.contains(dim.getId());
} else {
return Objects.nonNull(dim.getUseCnt());
}
})
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(limit)
.map(dimSchemaDesc -> {
@@ -84,14 +108,14 @@ public class RecommendServiceImpl implements RecommendService {
}
@Override
public RecommendResp recommendMetricMode(QueryReq queryCtx, Long limit) {
RecommendResp recommendResponse = recommend(queryCtx, limit);
public RecommendResp recommendMetricMode(RecommendReq recommendReq, Long limit) {
RecommendResp recommendResponse = recommend(recommendReq, limit);
// filter black Item
if (Objects.isNull(recommendResponse)) {
return recommendResponse;
}
ChatConfigRichResp chatConfigRich = configService.getConfigRichInfo(Long.valueOf(queryCtx.getModelId()));
ChatConfigRichResp chatConfigRich = configService.getConfigRichInfo(recommendReq.getModelId());
if (Objects.nonNull(chatConfigRich) && Objects.nonNull(chatConfigRich.getChatAggRichConfig())
&& Objects.nonNull(chatConfigRich.getChatAggRichConfig().getVisibility())) {
List<Long> blackMetricIdList = chatConfigRich.getChatAggRichConfig().getVisibility().getBlackMetricIdList();

View File

@@ -17,7 +17,7 @@ import org.springframework.util.CollectionUtils;
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
protected final Cache<String, List<ModelSchemaResp>> modelSchemaCache =
CacheBuilder.newBuilder().expireAfterWrite(60, TimeUnit.SECONDS).build();
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
@SneakyThrows
public List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable) {
@@ -33,13 +33,13 @@ public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
}
@Override
public ModelSchema getModelSchema(Long domain, Boolean cacheEnable) {
public ModelSchema getModelSchema(Long model, Boolean cacheEnable) {
List<Long> ids = new ArrayList<>();
ids.add(domain);
ids.add(model);
List<ModelSchemaResp> modelSchemaResps = fetchModelSchema(ids, cacheEnable);
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
Optional<ModelSchemaResp> modelSchemaResp = modelSchemaResps.stream()
.filter(d -> d.getId().equals(domain)).findFirst();
.filter(d -> d.getId().equals(model)).findFirst();
if (modelSchemaResp.isPresent()) {
ModelSchemaResp modelSchema = modelSchemaResp.get();
return ModelSchemaBuilder.build(modelSchema);

View File

@@ -1,11 +1,14 @@
package com.tencent.supersonic.knowledge.semantic;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap;
import com.tencent.supersonic.semantic.api.model.pojo.Entity;
import com.tencent.supersonic.semantic.api.model.pojo.RelateDimension;
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
@@ -29,7 +32,7 @@ public class ModelSchemaBuilder {
public static ModelSchema build(ModelSchemaResp resp) {
ModelSchema domainSchema = new ModelSchema();
ModelSchema modelSchema = new ModelSchema();
SchemaElement domain = SchemaElement.builder()
.model(resp.getId())
.id(resp.getId())
@@ -38,7 +41,7 @@ public class ModelSchemaBuilder {
.type(SchemaElementType.MODEL)
.alias(getAliasList(resp.getAlias()))
.build();
domainSchema.setModel(domain);
modelSchema.setModel(domain);
Set<SchemaElement> metrics = new HashSet<>();
for (MetricSchemaResp metric : resp.getMetrics()) {
@@ -53,12 +56,13 @@ public class ModelSchemaBuilder {
.type(SchemaElementType.METRIC)
.useCnt(metric.getUseCnt())
.alias(alias)
.relateSchemaElements(getRelateSchemaElement(metric))
.defaultAgg(metric.getDefaultAgg())
.build();
metrics.add(metricToAdd);
}
domainSchema.getMetrics().addAll(metrics);
modelSchema.getMetrics().addAll(metrics);
Set<SchemaElement> dimensions = new HashSet<>();
Set<SchemaElement> dimensionValues = new HashSet<>();
@@ -106,8 +110,8 @@ public class ModelSchemaBuilder {
.build();
dimensionValues.add(dimValueToAdd);
}
domainSchema.getDimensions().addAll(dimensions);
domainSchema.getDimensionValues().addAll(dimensionValues);
modelSchema.getDimensions().addAll(dimensions);
modelSchema.getDimensionValues().addAll(dimensionValues);
Entity entity = resp.getEntity();
if (Objects.nonNull(entity)) {
@@ -122,11 +126,11 @@ public class ModelSchemaBuilder {
entityElement.setType(SchemaElementType.ENTITY);
}
entityElement.setAlias(entity.getNames());
domainSchema.setEntity(entityElement);
modelSchema.setEntity(entityElement);
}
}
return domainSchema;
return modelSchema;
}
private static List<String> getAliasList(String alias) {
@@ -136,4 +140,16 @@ public class ModelSchemaBuilder {
return Arrays.asList(alias.split(aliasSplit));
}
private static List<RelateSchemaElement> getRelateSchemaElement(MetricSchemaResp metricSchemaResp) {
RelateDimension relateDimension = metricSchemaResp.getRelateDimension();
if (relateDimension == null || CollectionUtils.isEmpty(relateDimension.getDrillDownDimensions())) {
return Lists.newArrayList();
}
return relateDimension.getDrillDownDimensions().stream().map(dimension -> {
RelateSchemaElement relateSchemaElement = new RelateSchemaElement();
BeanUtils.copyProperties(dimension, relateSchemaElement);
return relateSchemaElement;
}).collect(Collectors.toList());
}
}