(improvement)(Headless) support semantic search in metric market (#934)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-04-24 11:02:41 +08:00
committed by GitHub
parent c45800bda6
commit 20d142fca8
7 changed files with 111 additions and 20 deletions

View File

@@ -65,6 +65,8 @@ public class MetricResp extends SchemaItem {
private Integer isPublish;
private double similarity;
public void setClassifications(String tag) {
if (StringUtils.isBlank(tag)) {
classifications = Lists.newArrayList();

View File

@@ -107,7 +107,7 @@ public class MetricController {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return metricService.queryMetric(pageMetricReq, user);
return metricService.queryMetricMarket(pageMetricReq, user);
}
@Deprecated

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO;
import java.util.List;
@@ -18,4 +19,5 @@ public interface CollectService {
List<CollectDO> getCollectList(String username);
List<CollectDO> getCollectList(String username, TypeEnums typeEnums);
}

View File

@@ -35,6 +35,8 @@ public interface MetricService {
void deleteMetric(Long id, User user) throws Exception;
PageInfo<MetricResp> queryMetricMarket(PageMetricReq pageMetricReq, User user);
PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user);
List<MetricResp> getMetrics(MetaFilter metaFilter);

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO;
import com.tencent.supersonic.headless.server.persistence.mapper.CollectMapper;
import com.tencent.supersonic.headless.server.service.CollectService;
@@ -60,4 +61,15 @@ public class CollectServiceImpl implements CollectService {
}
return collectMapper.selectList(queryWrapper);
}
@Override
public List<CollectDO> getCollectList(String username, TypeEnums typeEnums) {
QueryWrapper<CollectDO> queryWrapper = new QueryWrapper<>();
if (!StringUtils.isEmpty(username)) {
queryWrapper.lambda().eq(CollectDO::getUsername, username);
}
queryWrapper.lambda().eq(CollectDO::getType, typeEnums.name().toLowerCase());
return collectMapper.selectList(queryWrapper);
}
}

View File

@@ -22,17 +22,22 @@ import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricParam;
import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricBaseReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.TagItem;
@@ -51,6 +56,7 @@ import com.tencent.supersonic.headless.server.pojo.TagFilter;
import com.tencent.supersonic.headless.server.service.CollectService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetaDiscoveryService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.TagMetaService;
@@ -61,10 +67,12 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
@@ -96,6 +104,8 @@ public class MetricServiceImpl implements MetricService {
private TagMetaService tagMetaService;
private MetaDiscoveryService metaDiscoveryService;
public MetricServiceImpl(MetricRepository metricRepository,
ModelService modelService,
ChatGptHelper chatGptHelper,
@@ -103,7 +113,8 @@ public class MetricServiceImpl implements MetricService {
DataSetService dataSetService,
ApplicationEventPublisher eventPublisher,
DimensionService dimensionService,
TagMetaService tagMetaService) {
TagMetaService tagMetaService,
@Lazy MetaDiscoveryService metaDiscoveryService) {
this.metricRepository = metricRepository;
this.modelService = modelService;
this.chatGptHelper = chatGptHelper;
@@ -112,6 +123,7 @@ public class MetricServiceImpl implements MetricService {
this.dataSetService = dataSetService;
this.dimensionService = dimensionService;
this.tagMetaService = tagMetaService;
this.metaDiscoveryService = metaDiscoveryService;
}
@Override
@@ -230,24 +242,57 @@ public class MetricServiceImpl implements MetricService {
sendEventBatch(Lists.newArrayList(metricDO), EventType.DELETE);
}
@Override
public PageInfo<MetricResp> queryMetricMarket(PageMetricReq pageMetricReq, User user) {
//search by whole text
PageInfo<MetricResp> metricRespPageInfo = queryMetric(pageMetricReq, user);
if (metricRespPageInfo.hasContent() || StringUtils.isBlank(pageMetricReq.getKey())) {
return metricRespPageInfo;
}
//search by text split
QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText(pageMetricReq.getKey());
queryMapReq.setUser(user);
queryMapReq.setMapModeEnum(MapModeEnum.MODERATE);
MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq);
Map<String, List<SchemaElementMatch>> mapFields = mapMeta.getMapFields();
if (CollectionUtils.isEmpty(mapFields)) {
return metricRespPageInfo;
}
Map<Long, Double> result = mapFields.values().stream()
.flatMap(Collection::stream).filter(schemaElementMatch ->
SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toMap(schemaElementMatch ->
schemaElementMatch.getElement().getId(), SchemaElementMatch::getSimilarity,
(existingValue, newValue) -> existingValue));
List<Long> metricIds = new ArrayList<>(result.keySet());
if (CollectionUtils.isEmpty(result.keySet())) {
return metricRespPageInfo;
}
pageMetricReq.setIds(metricIds);
pageMetricReq.setKey("");
PageInfo<MetricResp> metricPage = queryMetric(pageMetricReq, user);
for (MetricResp metricResp : metricPage.getList()) {
metricResp.setSimilarity(result.get(metricResp.getId()));
}
metricPage.getList().sort(Comparator.comparingDouble(MetricResp::getSimilarity).reversed());
return metricPage;
}
@Override
public PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user) {
MetricFilter metricFilter = new MetricFilter();
metricFilter.setUserName(user.getName());
BeanUtils.copyProperties(pageMetricReq, metricFilter);
if (!CollectionUtils.isEmpty(pageMetricReq.getDomainIds())) {
List<ModelResp> modelResps = modelService.getAllModelByDomainIds(pageMetricReq.getDomainIds());
List<Long> modelIds = modelResps.stream().map(ModelResp::getId).collect(Collectors.toList());
pageMetricReq.getModelIds().addAll(modelIds);
}
metricFilter.setModelIds(pageMetricReq.getModelIds());
List<CollectDO> collectList = collectService.getCollectList(user.getName());
List<Long> collectIds = collectList.stream().map(CollectDO::getCollectId).collect(Collectors.toList());
if (pageMetricReq.isHasCollect()) {
if (CollectionUtils.isEmpty(collectIds)) {
metricFilter.setIds(Lists.newArrayList(-1L));
} else {
metricFilter.setIds(collectIds);
}
}
List<Long> collectIds = getCollectIds(pageMetricReq, user);
List<Long> idsToFilter = getIdsToFilter(pageMetricReq, collectIds);
metricFilter.setIds(idsToFilter);
PageInfo<MetricDO> metricDOPageInfo = PageHelper.startPage(pageMetricReq.getCurrent(),
pageMetricReq.getPageSize())
.doSelectPageInfo(() -> queryMetric(metricFilter));
@@ -293,6 +338,32 @@ public class MetricServiceImpl implements MetricService {
return metricResps;
}
private List<Long> getCollectIds(PageMetricReq pageMetricReq, User user) {
List<CollectDO> collectList = collectService.getCollectList(user.getName(), TypeEnums.METRIC);
List<Long> collectIds = collectList.stream().map(CollectDO::getCollectId).collect(Collectors.toList());
if (pageMetricReq.isHasCollect()) {
if (CollectionUtils.isEmpty(collectIds)) {
return Lists.newArrayList(-1L);
} else {
return collectIds;
}
}
return Lists.newArrayList();
}
private List<Long> getIdsToFilter(PageMetricReq pageMetricReq, List<Long> collectIds) {
if (CollectionUtils.isEmpty(pageMetricReq.getIds())) {
return collectIds;
}
if (CollectionUtils.isEmpty(collectIds)) {
return pageMetricReq.getIds();
}
List<Long> idsToFilter = new ArrayList<>(collectIds);
idsToFilter.retainAll(pageMetricReq.getIds());
idsToFilter.add(-1L);
return idsToFilter;
}
private void fillTagInfo(List<MetricResp> metricRespList) {
if (CollectionUtils.isEmpty(metricRespList)) {
return;
@@ -405,7 +476,7 @@ public class MetricServiceImpl implements MetricService {
ModelFilter modelFilter = new ModelFilter(false,
Lists.newArrayList(metricDO.getModelId()));
Map<Long, ModelResp> modelMap = modelService.getModelMap(modelFilter);
List<CollectDO> collectList = collectService.getCollectList(user.getName());
List<CollectDO> collectList = collectService.getCollectList(user.getName(), TypeEnums.METRIC);
List<Long> collect = collectList.stream().map(CollectDO::getCollectId).collect(Collectors.toList());
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO, modelMap, collect);
fillAdminRes(Lists.newArrayList(metricResp), user);

View File

@@ -1,8 +1,5 @@
package com.tencent.supersonic.headless.server.service;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DataFormat;
@@ -25,12 +22,16 @@ import com.tencent.supersonic.headless.server.persistence.repository.MetricRepos
import com.tencent.supersonic.headless.server.service.impl.DataSetServiceImpl;
import com.tencent.supersonic.headless.server.service.impl.MetricServiceImpl;
import com.tencent.supersonic.headless.server.utils.MetricConverter;
import java.util.HashMap;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.context.ApplicationEventPublisher;
import java.util.HashMap;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
public class MetricServiceImplTest {
@Test
@@ -69,8 +70,9 @@ public class MetricServiceImplTest {
DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class);
DimensionService dimensionService = Mockito.mock(DimensionService.class);
TagMetaService tagMetaService = Mockito.mock(TagMetaService.class);
MetaDiscoveryService metaDiscoveryService = Mockito.mock(MetaDiscoveryService.class);
return new MetricServiceImpl(metricRepository, modelService, chatGptHelper, collectService, dataSetService,
eventPublisher, dimensionService, tagMetaService);
eventPublisher, dimensionService, tagMetaService, metaDiscoveryService);
}
private MetricReq buildMetricReq() {