mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(Headless) support semantic search in metric market (#934)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -65,6 +65,8 @@ public class MetricResp extends SchemaItem {
|
||||
|
||||
private Integer isPublish;
|
||||
|
||||
private double similarity;
|
||||
|
||||
public void setClassifications(String tag) {
|
||||
if (StringUtils.isBlank(tag)) {
|
||||
classifications = Lists.newArrayList();
|
||||
|
||||
@@ -107,7 +107,7 @@ public class MetricController {
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return metricService.queryMetric(pageMetricReq, user);
|
||||
return metricService.queryMetricMarket(pageMetricReq, user);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO;
|
||||
import java.util.List;
|
||||
|
||||
@@ -18,4 +19,5 @@ public interface CollectService {
|
||||
|
||||
List<CollectDO> getCollectList(String username);
|
||||
|
||||
List<CollectDO> getCollectList(String username, TypeEnums typeEnums);
|
||||
}
|
||||
|
||||
@@ -35,6 +35,8 @@ public interface MetricService {
|
||||
|
||||
void deleteMetric(Long id, User user) throws Exception;
|
||||
|
||||
PageInfo<MetricResp> queryMetricMarket(PageMetricReq pageMetricReq, User user);
|
||||
|
||||
PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user);
|
||||
|
||||
List<MetricResp> getMetrics(MetaFilter metaFilter);
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.mapper.CollectMapper;
|
||||
import com.tencent.supersonic.headless.server.service.CollectService;
|
||||
@@ -60,4 +61,15 @@ public class CollectServiceImpl implements CollectService {
|
||||
}
|
||||
return collectMapper.selectList(queryWrapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<CollectDO> getCollectList(String username, TypeEnums typeEnums) {
|
||||
QueryWrapper<CollectDO> queryWrapper = new QueryWrapper<>();
|
||||
if (!StringUtils.isEmpty(username)) {
|
||||
queryWrapper.lambda().eq(CollectDO::getUsername, username);
|
||||
}
|
||||
queryWrapper.lambda().eq(CollectDO::getType, typeEnums.name().toLowerCase());
|
||||
return collectMapper.selectList(queryWrapper);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,17 +22,22 @@ import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
|
||||
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
|
||||
import com.tencent.supersonic.headless.api.pojo.MetricParam;
|
||||
import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetricBaseReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.TagItem;
|
||||
@@ -51,6 +56,7 @@ import com.tencent.supersonic.headless.server.pojo.TagFilter;
|
||||
import com.tencent.supersonic.headless.server.service.CollectService;
|
||||
import com.tencent.supersonic.headless.server.service.DataSetService;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetaDiscoveryService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.service.TagMetaService;
|
||||
@@ -61,10 +67,12 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.Date;
|
||||
@@ -96,6 +104,8 @@ public class MetricServiceImpl implements MetricService {
|
||||
|
||||
private TagMetaService tagMetaService;
|
||||
|
||||
private MetaDiscoveryService metaDiscoveryService;
|
||||
|
||||
public MetricServiceImpl(MetricRepository metricRepository,
|
||||
ModelService modelService,
|
||||
ChatGptHelper chatGptHelper,
|
||||
@@ -103,7 +113,8 @@ public class MetricServiceImpl implements MetricService {
|
||||
DataSetService dataSetService,
|
||||
ApplicationEventPublisher eventPublisher,
|
||||
DimensionService dimensionService,
|
||||
TagMetaService tagMetaService) {
|
||||
TagMetaService tagMetaService,
|
||||
@Lazy MetaDiscoveryService metaDiscoveryService) {
|
||||
this.metricRepository = metricRepository;
|
||||
this.modelService = modelService;
|
||||
this.chatGptHelper = chatGptHelper;
|
||||
@@ -112,6 +123,7 @@ public class MetricServiceImpl implements MetricService {
|
||||
this.dataSetService = dataSetService;
|
||||
this.dimensionService = dimensionService;
|
||||
this.tagMetaService = tagMetaService;
|
||||
this.metaDiscoveryService = metaDiscoveryService;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -230,24 +242,57 @@ public class MetricServiceImpl implements MetricService {
|
||||
sendEventBatch(Lists.newArrayList(metricDO), EventType.DELETE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<MetricResp> queryMetricMarket(PageMetricReq pageMetricReq, User user) {
|
||||
//search by whole text
|
||||
PageInfo<MetricResp> metricRespPageInfo = queryMetric(pageMetricReq, user);
|
||||
if (metricRespPageInfo.hasContent() || StringUtils.isBlank(pageMetricReq.getKey())) {
|
||||
return metricRespPageInfo;
|
||||
}
|
||||
//search by text split
|
||||
QueryMapReq queryMapReq = new QueryMapReq();
|
||||
queryMapReq.setQueryText(pageMetricReq.getKey());
|
||||
queryMapReq.setUser(user);
|
||||
queryMapReq.setMapModeEnum(MapModeEnum.MODERATE);
|
||||
MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq);
|
||||
Map<String, List<SchemaElementMatch>> mapFields = mapMeta.getMapFields();
|
||||
if (CollectionUtils.isEmpty(mapFields)) {
|
||||
return metricRespPageInfo;
|
||||
}
|
||||
Map<Long, Double> result = mapFields.values().stream()
|
||||
.flatMap(Collection::stream).filter(schemaElementMatch ->
|
||||
SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toMap(schemaElementMatch ->
|
||||
schemaElementMatch.getElement().getId(), SchemaElementMatch::getSimilarity,
|
||||
(existingValue, newValue) -> existingValue));
|
||||
List<Long> metricIds = new ArrayList<>(result.keySet());
|
||||
if (CollectionUtils.isEmpty(result.keySet())) {
|
||||
return metricRespPageInfo;
|
||||
}
|
||||
pageMetricReq.setIds(metricIds);
|
||||
pageMetricReq.setKey("");
|
||||
PageInfo<MetricResp> metricPage = queryMetric(pageMetricReq, user);
|
||||
for (MetricResp metricResp : metricPage.getList()) {
|
||||
metricResp.setSimilarity(result.get(metricResp.getId()));
|
||||
}
|
||||
metricPage.getList().sort(Comparator.comparingDouble(MetricResp::getSimilarity).reversed());
|
||||
return metricPage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user) {
|
||||
MetricFilter metricFilter = new MetricFilter();
|
||||
metricFilter.setUserName(user.getName());
|
||||
BeanUtils.copyProperties(pageMetricReq, metricFilter);
|
||||
List<ModelResp> modelResps = modelService.getAllModelByDomainIds(pageMetricReq.getDomainIds());
|
||||
List<Long> modelIds = modelResps.stream().map(ModelResp::getId).collect(Collectors.toList());
|
||||
pageMetricReq.getModelIds().addAll(modelIds);
|
||||
metricFilter.setModelIds(pageMetricReq.getModelIds());
|
||||
List<CollectDO> collectList = collectService.getCollectList(user.getName());
|
||||
List<Long> collectIds = collectList.stream().map(CollectDO::getCollectId).collect(Collectors.toList());
|
||||
if (pageMetricReq.isHasCollect()) {
|
||||
if (CollectionUtils.isEmpty(collectIds)) {
|
||||
metricFilter.setIds(Lists.newArrayList(-1L));
|
||||
} else {
|
||||
metricFilter.setIds(collectIds);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(pageMetricReq.getDomainIds())) {
|
||||
List<ModelResp> modelResps = modelService.getAllModelByDomainIds(pageMetricReq.getDomainIds());
|
||||
List<Long> modelIds = modelResps.stream().map(ModelResp::getId).collect(Collectors.toList());
|
||||
pageMetricReq.getModelIds().addAll(modelIds);
|
||||
}
|
||||
metricFilter.setModelIds(pageMetricReq.getModelIds());
|
||||
List<Long> collectIds = getCollectIds(pageMetricReq, user);
|
||||
List<Long> idsToFilter = getIdsToFilter(pageMetricReq, collectIds);
|
||||
metricFilter.setIds(idsToFilter);
|
||||
PageInfo<MetricDO> metricDOPageInfo = PageHelper.startPage(pageMetricReq.getCurrent(),
|
||||
pageMetricReq.getPageSize())
|
||||
.doSelectPageInfo(() -> queryMetric(metricFilter));
|
||||
@@ -293,6 +338,32 @@ public class MetricServiceImpl implements MetricService {
|
||||
return metricResps;
|
||||
}
|
||||
|
||||
private List<Long> getCollectIds(PageMetricReq pageMetricReq, User user) {
|
||||
List<CollectDO> collectList = collectService.getCollectList(user.getName(), TypeEnums.METRIC);
|
||||
List<Long> collectIds = collectList.stream().map(CollectDO::getCollectId).collect(Collectors.toList());
|
||||
if (pageMetricReq.isHasCollect()) {
|
||||
if (CollectionUtils.isEmpty(collectIds)) {
|
||||
return Lists.newArrayList(-1L);
|
||||
} else {
|
||||
return collectIds;
|
||||
}
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
private List<Long> getIdsToFilter(PageMetricReq pageMetricReq, List<Long> collectIds) {
|
||||
if (CollectionUtils.isEmpty(pageMetricReq.getIds())) {
|
||||
return collectIds;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(collectIds)) {
|
||||
return pageMetricReq.getIds();
|
||||
}
|
||||
List<Long> idsToFilter = new ArrayList<>(collectIds);
|
||||
idsToFilter.retainAll(pageMetricReq.getIds());
|
||||
idsToFilter.add(-1L);
|
||||
return idsToFilter;
|
||||
}
|
||||
|
||||
private void fillTagInfo(List<MetricResp> metricRespList) {
|
||||
if (CollectionUtils.isEmpty(metricRespList)) {
|
||||
return;
|
||||
@@ -405,7 +476,7 @@ public class MetricServiceImpl implements MetricService {
|
||||
ModelFilter modelFilter = new ModelFilter(false,
|
||||
Lists.newArrayList(metricDO.getModelId()));
|
||||
Map<Long, ModelResp> modelMap = modelService.getModelMap(modelFilter);
|
||||
List<CollectDO> collectList = collectService.getCollectList(user.getName());
|
||||
List<CollectDO> collectList = collectService.getCollectList(user.getName(), TypeEnums.METRIC);
|
||||
List<Long> collect = collectList.stream().map(CollectDO::getCollectId).collect(Collectors.toList());
|
||||
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO, modelMap, collect);
|
||||
fillAdminRes(Lists.newArrayList(metricResp), user);
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.DataFormat;
|
||||
@@ -25,12 +22,16 @@ import com.tencent.supersonic.headless.server.persistence.repository.MetricRepos
|
||||
import com.tencent.supersonic.headless.server.service.impl.DataSetServiceImpl;
|
||||
import com.tencent.supersonic.headless.server.service.impl.MetricServiceImpl;
|
||||
import com.tencent.supersonic.headless.server.utils.MetricConverter;
|
||||
import java.util.HashMap;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class MetricServiceImplTest {
|
||||
|
||||
@Test
|
||||
@@ -69,8 +70,9 @@ public class MetricServiceImplTest {
|
||||
DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class);
|
||||
DimensionService dimensionService = Mockito.mock(DimensionService.class);
|
||||
TagMetaService tagMetaService = Mockito.mock(TagMetaService.class);
|
||||
MetaDiscoveryService metaDiscoveryService = Mockito.mock(MetaDiscoveryService.class);
|
||||
return new MetricServiceImpl(metricRepository, modelService, chatGptHelper, collectService, dataSetService,
|
||||
eventPublisher, dimensionService, tagMetaService);
|
||||
eventPublisher, dimensionService, tagMetaService, metaDiscoveryService);
|
||||
}
|
||||
|
||||
private MetricReq buildMetricReq() {
|
||||
|
||||
Reference in New Issue
Block a user