mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(headless) Headless integration embedding functionality, with support for viewId in embeddings. (#725)
This commit is contained in:
@@ -2,12 +2,15 @@ package com.tencent.supersonic.headless.server.listener;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -16,10 +19,6 @@ import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
||||
@@ -38,12 +37,13 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
||||
if (CollectionUtils.isEmpty(event.getDataItems())) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<EmbeddingQuery> embeddingQueries = event.getDataItems()
|
||||
.stream()
|
||||
.map(dataItem -> {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(
|
||||
dataItem.getId().toString() + DictWordType.NATURE_SPILT
|
||||
dataItem.getId().toString() + Constants.UNDERLINE
|
||||
+ dataItem.getType().name().toLowerCase());
|
||||
embeddingQuery.setQuery(dataItem.getName());
|
||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import java.util.List;
|
||||
|
||||
public interface MetaEmbeddingService {
|
||||
|
||||
List<RetrieveQueryResult> retrieveQuery(List<Long> viewIds, RetrieveQuery retrieveQuery, int num);
|
||||
|
||||
}
|
||||
@@ -19,6 +19,8 @@ public interface ViewService {
|
||||
|
||||
void delete(Long id, User user);
|
||||
|
||||
List<ViewResp> getViewListByCache(MetaFilter metaFilter);
|
||||
|
||||
List<ViewResp> getViews(User user);
|
||||
|
||||
List<ViewResp> getViewsInheritAuth(User user, Long domainId);
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
@@ -72,11 +73,11 @@ public class DimensionServiceImpl implements DimensionService {
|
||||
|
||||
|
||||
public DimensionServiceImpl(DimensionRepository dimensionRepository,
|
||||
ModelService modelService,
|
||||
ChatGptHelper chatGptHelper,
|
||||
DatabaseService databaseService,
|
||||
ModelRelaService modelRelaService,
|
||||
ViewService viewService) {
|
||||
ModelService modelService,
|
||||
ChatGptHelper chatGptHelper,
|
||||
DatabaseService databaseService,
|
||||
ModelRelaService modelRelaService,
|
||||
ViewService viewService) {
|
||||
this.modelService = modelService;
|
||||
this.dimensionRepository = dimensionRepository;
|
||||
this.chatGptHelper = chatGptHelper;
|
||||
@@ -129,8 +130,8 @@ public class DimensionServiceImpl implements DimensionService {
|
||||
DimensionConverter.convert(dimensionDO, dimensionReq);
|
||||
dimensionRepository.updateDimension(dimensionDO);
|
||||
if (!oldName.equals(dimensionDO.getName())) {
|
||||
sendEvent(DataItem.builder().modelId(dimensionDO.getModelId()).newName(dimensionReq.getName())
|
||||
.name(oldName).type(TypeEnums.DIMENSION)
|
||||
sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
|
||||
.newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION)
|
||||
.id(dimensionDO.getId()).build(), EventType.UPDATE);
|
||||
}
|
||||
}
|
||||
@@ -264,7 +265,7 @@ public class DimensionServiceImpl implements DimensionService {
|
||||
}
|
||||
|
||||
private List<DimensionResp> convertList(List<DimensionDO> dimensionDOS,
|
||||
Map<Long, ModelResp> modelRespMap) {
|
||||
Map<Long, ModelResp> modelRespMap) {
|
||||
List<DimensionResp> dimensionResps = Lists.newArrayList();
|
||||
if (!CollectionUtils.isEmpty(dimensionDOS)) {
|
||||
dimensionResps = dimensionDOS.stream()
|
||||
@@ -364,9 +365,9 @@ public class DimensionServiceImpl implements DimensionService {
|
||||
}
|
||||
|
||||
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
|
||||
List<DataItem> dataItems = dimensionDOS.stream().map(dimensionDO ->
|
||||
DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName())
|
||||
.modelId(dimensionDO.getModelId()).type(TypeEnums.DIMENSION).build())
|
||||
List<DataItem> dataItems = dimensionDOS.stream()
|
||||
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName())
|
||||
.modelId(dimensionDO.getModelId() + Constants.UNDERLINE).type(TypeEnums.DIMENSION).build())
|
||||
.collect(Collectors.toList());
|
||||
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
|
||||
}
|
||||
@@ -376,10 +377,4 @@ public class DimensionServiceImpl implements DimensionService {
|
||||
Lists.newArrayList(dataItem), eventType));
|
||||
}
|
||||
|
||||
private DataItem getDataItem(DimensionDO dimensionDO) {
|
||||
return DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName())
|
||||
.bizName(dimensionDO.getBizName())
|
||||
.modelId(dimensionDO.getModelId()).type(TypeEnums.DIMENSION).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package com.tencent.supersonic.headless.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class MetaEmbeddingServiceImpl implements MetaEmbeddingService {
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Autowired
|
||||
private ViewService viewService;
|
||||
|
||||
@Override
|
||||
public List<RetrieveQueryResult> retrieveQuery(List<Long> viewIds, RetrieveQuery retrieveQuery, int num) {
|
||||
// viewIds->modelIds
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setStatus(StatusEnum.ONLINE.getCode());
|
||||
metaFilter.setIds(viewIds);
|
||||
List<ViewResp> viewListByCache = viewService.getViewListByCache(metaFilter);
|
||||
Set<Long> allModels = getModels(viewListByCache);
|
||||
|
||||
Map<Long, List<Long>> modelIdToViewIds = viewListByCache.stream()
|
||||
.flatMap(viewResp -> viewResp.getAllModels().stream()
|
||||
.map(modelId -> Pair.of(modelId, viewResp.getId())))
|
||||
.collect(Collectors.groupingBy(Pair::getLeft, Collectors.mapping(Pair::getRight, Collectors.toList())));
|
||||
|
||||
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", allModels.stream().findFirst().get().toString());
|
||||
retrieveQuery.setFilterCondition(filterCondition);
|
||||
}
|
||||
|
||||
String collectionName = embeddingConfig.getMetaCollectionName();
|
||||
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery, num);
|
||||
if (CollectionUtils.isEmpty(resultList)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
//filter by modelId
|
||||
if (CollectionUtils.isEmpty(allModels)) {
|
||||
return resultList;
|
||||
}
|
||||
return resultList.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isEmpty(retrievals)) {
|
||||
return retrieveQueryResult;
|
||||
}
|
||||
//filter by modelId
|
||||
retrievals.removeIf(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
if (Objects.isNull(modelId)) {
|
||||
return CollectionUtils.isEmpty(allModels);
|
||||
}
|
||||
return !allModels.contains(modelId);
|
||||
});
|
||||
//add viewId
|
||||
retrievals = retrievals.stream().flatMap(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
List<Long> viewIdsByModelId = modelIdToViewIds.get(modelId);
|
||||
if (!CollectionUtils.isEmpty(viewIdsByModelId)) {
|
||||
Set<Retrieval> result = new HashSet<>();
|
||||
for (Long viewId : viewIdsByModelId) {
|
||||
Retrieval retrievalNew = new Retrieval();
|
||||
BeanUtils.copyProperties(retrieval, retrievalNew);
|
||||
retrievalNew.getMetadata().putIfAbsent("viewId", viewId + Constants.UNDERLINE);
|
||||
result.add(retrievalNew);
|
||||
}
|
||||
return result.stream();
|
||||
}
|
||||
Set<Retrieval> result = new HashSet<>();
|
||||
result.add(retrieval);
|
||||
return result.stream();
|
||||
}).collect(Collectors.toList());
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private Set<Long> getModels(List<ViewResp> viewListByCache) {
|
||||
return viewListByCache.stream()
|
||||
.flatMap(viewResp -> viewResp.getAllModels().stream())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
@@ -259,8 +260,8 @@ public class MetricServiceImpl implements MetricService {
|
||||
metricFilter.setModelIds(Lists.newArrayList(modelId));
|
||||
List<MetricResp> metricResps = getMetrics(metricFilter);
|
||||
return metricResps.stream().filter(metricResp ->
|
||||
MetricDefineType.FIELD.equals(metricResp.getMetricDefineType())
|
||||
|| MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType()))
|
||||
MetricDefineType.FIELD.equals(metricResp.getMetricDefineType())
|
||||
|| MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -450,8 +451,8 @@ public class MetricServiceImpl implements MetricService {
|
||||
new HashMap<>(), Lists.newArrayList());
|
||||
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
|
||||
.bizName(metricDO.getBizName())
|
||||
.modelId(metricDO.getModelId()).type(TypeEnums.METRIC)
|
||||
.defaultAgg(metricResp.getDefaultAgg()).build();
|
||||
.modelId(metricDO.getModelId() + Constants.UNDERLINE)
|
||||
.type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package com.tencent.supersonic.headless.server.service.impl;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
@@ -19,23 +21,26 @@ import com.tencent.supersonic.headless.server.persistence.mapper.ViewDOMapper;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.service.DomainService;
|
||||
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.Date;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Service
|
||||
public class ViewServiceImpl
|
||||
extends ServiceImpl<ViewDOMapper, ViewDO> implements ViewService {
|
||||
|
||||
protected final Cache<MetaFilter, List<ViewResp>> viewSchemaCache =
|
||||
CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.SECONDS).build();
|
||||
|
||||
@Autowired
|
||||
private DomainService domainService;
|
||||
|
||||
@@ -153,4 +158,14 @@ public class ViewServiceImpl
|
||||
return admins.contains(userName) || viewResp.getCreatedBy().equals(userName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ViewResp> getViewListByCache(MetaFilter metaFilter) {
|
||||
List<ViewResp> viewList = viewSchemaCache.getIfPresent(metaFilter);
|
||||
if (CollectionUtils.isEmpty(viewList)) {
|
||||
viewList = getViewList(metaFilter);
|
||||
viewSchemaCache.put(metaFilter, viewList);
|
||||
}
|
||||
return viewList;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user