diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/BaseMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/BaseMatchStrategy.java index af9a49a58..df23973a8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/BaseMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/BaseMatchStrategy.java @@ -44,7 +44,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { return result; } - public List detect(QueryContext queryContext, List terms, Set detectModelIds) { + public List detect(QueryContext queryContext, List terms, Set detectViewIds) { Map regOffsetToLength = getRegOffsetToLength(terms); String text = queryContext.getQueryText(); Set results = new HashSet<>(); @@ -59,16 +59,16 @@ public abstract class BaseMatchStrategy implements MatchStrategy { if (index <= text.length()) { String detectSegment = text.substring(startIndex, index); detectSegments.add(detectSegment); - detectByStep(queryContext, results, detectModelIds, startIndex, index, offset); + detectByStep(queryContext, results, detectViewIds, startIndex, index, offset); } } startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); } - detectByBatch(queryContext, results, detectModelIds, detectSegments); + detectByBatch(queryContext, results, detectViewIds, detectSegments); return new ArrayList<>(results); } - protected void detectByBatch(QueryContext queryContext, Set results, Set detectModelIds, + protected void detectByBatch(QueryContext queryContext, Set results, Set detectViewIds, Set detectSegments) { return; } @@ -152,6 +152,6 @@ public abstract class BaseMatchStrategy implements MatchStrategy { public abstract String getMapKey(T a); public abstract void detectByStep(QueryContext queryContext, Set results, - Set detectModelIds, Integer startIndex, Integer index, int offset); + Set detectViewIds, Integer startIndex, Integer index, int offset); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/DatabaseMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/DatabaseMatchStrategy.java index 27726248f..647a5e964 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/DatabaseMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/DatabaseMatchStrategy.java @@ -37,9 +37,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy @Override public Map> match(QueryContext queryContext, List terms, - Set detectModelIds) { + Set detectViewIds) { this.allElements = getSchemaElements(queryContext); - return super.match(queryContext, terms, detectModelIds); + return super.match(queryContext, terms, detectViewIds); } @Override @@ -54,7 +54,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy + Constants.UNDERLINE + a.getSchemaElement().getName(); } - public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, + public void detectByStep(QueryContext queryContext, Set existResults, Set detectViewIds, Integer startIndex, Integer index, int offset) { String detectSegment = queryContext.getQueryText().substring(startIndex, index); if (StringUtils.isBlank(detectSegment)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMapper.java index 8182c4c8b..fb4854579 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMapper.java @@ -1,20 +1,18 @@ package com.tencent.supersonic.chat.core.mapper; import com.alibaba.fastjson.JSONObject; -import com.tencent.supersonic.chat.core.pojo.QueryContext; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.core.pojo.QueryContext; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.embedding.Retrieval; import java.util.List; import java.util.Objects; - import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; /*** * A mapper that recognizes schema elements with vector embedding. @@ -39,15 +37,11 @@ public class EmbeddingMapper extends BaseMapper { SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()), SchemaElement.class); - if (Objects.isNull(matchResult.getMetadata())) { + Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId")); + if (Objects.isNull(viewId)) { continue; } - String modelIdStr = matchResult.getMetadata().get("modelId"); - if (StringUtils.isBlank(modelIdStr)) { - continue; - } - long modelId = Long.parseLong(modelIdStr); - schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId, + schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId, queryContext.getSemanticSchema()); if (schemaElement == null) { continue; @@ -60,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper { .detectWord(matchResult.getDetectWord()) .build(); //3. add to mapInfo - addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch); + addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch); } } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMatchStrategy.java index a4461f257..790f1e73b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/EmbeddingMatchStrategy.java @@ -4,15 +4,13 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult; import com.tencent.supersonic.chat.core.pojo.QueryContext; -import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; -import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; +import com.tencent.supersonic.headless.server.service.MetaEmbeddingService; +import java.util.ArrayList; import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -36,9 +34,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { private OptimizationConfig optimizationConfig; @Autowired - private EmbeddingConfig embeddingConfig; - - private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); + private MetaEmbeddingService metaEmbeddingService; @Override public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) { @@ -52,7 +48,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { } @Override - protected void detectByBatch(QueryContext queryContext, Set results, Set detectModelIds, + protected void detectByBatch(QueryContext queryContext, Set results, Set detectViewIds, Set detectSegments) { List queryTextsList = detectSegments.stream() @@ -66,51 +62,29 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { optimizationConfig.getEmbeddingMapperBatch()); for (List queryTextsSub : queryTextsSubList) { - detectByQueryTextsSub(results, detectModelIds, queryTextsSub); + detectByQueryTextsSub(results, detectViewIds, queryTextsSub); } } - private void detectByQueryTextsSub(Set results, Set detectModelIds, + private void detectByQueryTextsSub(Set results, Set detectViewIds, List queryTextsSub) { int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber(); Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold(); - Map filterCondition = null; // step1. build query params - // if only one modelId, add to filterCondition - if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) { - filterCondition = new HashMap<>(); - filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString()); - } - - RetrieveQuery retrieveQuery = RetrieveQuery.builder() - .queryTextsList(queryTextsSub) - .filterCondition(filterCondition) - .queryEmbeddings(null) - .build(); + RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); // step2. retrieveQuery by detectSegment - List retrieveQueryResults = s2EmbeddingStore.retrieveQuery( - embeddingConfig.getMetaCollectionName(), retrieveQuery, embeddingNumber); + List retrieveQueryResults = metaEmbeddingService.retrieveQuery( + new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber); if (CollectionUtils.isEmpty(retrieveQueryResults)) { return; } - // step3. build EmbeddingResults. filter by modelId + // step3. build EmbeddingResults List collect = retrieveQueryResults.stream() .map(retrieveQueryResult -> { List retrievals = retrieveQueryResult.getRetrieval(); if (CollectionUtils.isNotEmpty(retrievals)) { retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue()); - if (CollectionUtils.isNotEmpty(detectModelIds)) { - retrievals.removeIf(retrieval -> { - String modelIdStr = retrieval.getMetadata().get("modelId").toString(); - if (StringUtils.isBlank(modelIdStr)) { - return true; - } - //return detectModelIds.contains(Long.parseLong(modelIdStr)); - Double modelId = Double.parseDouble(modelIdStr); - return detectModelIds.contains(modelId.longValue()); - }); - } } return retrieveQueryResult; }) @@ -121,6 +95,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { BeanUtils.copyProperties(retrieval, embeddingResult); embeddingResult.setDetectWord(retrieveQueryResult.getQuery()); embeddingResult.setName(retrieval.getQuery()); + Map convertedMap = retrieval.getMetadata().entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString())); + embeddingResult.setMetadata(convertedMap); return embeddingResult; })) .collect(Collectors.toList()); @@ -135,7 +112,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { } @Override - public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, + public void detectByStep(QueryContext queryContext, Set existResults, Set detectViewIds, Integer startIndex, Integer index, int offset) { return; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/HanlpDictMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/HanlpDictMatchStrategy.java index 923b1d157..6be8bc976 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/HanlpDictMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/HanlpDictMatchStrategy.java @@ -36,15 +36,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { @Override public Map> match(QueryContext queryContext, List terms, - Set detectModelIds) { + Set detectViewIds) { String text = queryContext.getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; } - log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds); + log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds); - List detects = detect(queryContext, terms, detectModelIds); + List detects = detect(queryContext, terms, detectViewIds); Map> result = new HashMap<>(); result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); @@ -57,7 +57,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length(); } - public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, + public void detectByStep(QueryContext queryContext, Set existResults, Set detectViewIds, Integer startIndex, Integer index, int offset) { String text = queryContext.getQueryText(); String detectSegment = text.substring(startIndex, index); @@ -65,11 +65,10 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { // step1. pre search Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); LinkedHashSet hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, - detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); + detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); // step2. suffix search LinkedHashSet suffixHanlpMapResults = SearchService.suffixSearch(detectSegment, - oneDetectionMaxSize, detectModelIds).stream() - .collect(Collectors.toCollection(LinkedHashSet::new)); + oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); hanlpMapResults.addAll(suffixHanlpMapResults); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MatchStrategy.java index 6c064df99..6987600de 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MatchStrategy.java @@ -13,6 +13,6 @@ import java.util.Set; */ public interface MatchStrategy { - Map> match(QueryContext queryContext, List terms, Set detectModelId); + Map> match(QueryContext queryContext, List terms, Set detectViewIds); } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/SearchMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/SearchMatchStrategy.java index 13f264534..a4a6b7565 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/SearchMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/SearchMatchStrategy.java @@ -27,7 +27,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy { @Override public Map> match(QueryContext queryContext, List originals, - Set detectModelIds) { + Set detectViewIds) { String text = queryContext.getQueryText(); Map regOffsetToLength = getRegOffsetToLength(originals); @@ -52,9 +52,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy { if (StringUtils.isNotEmpty(detectSegment)) { List hanlpMapResults = SearchService.prefixSearch(detectSegment, - SearchService.SEARCH_SIZE, detectModelIds); + SearchService.SEARCH_SIZE, detectViewIds); List suffixHanlpMapResults = SearchService.suffixSearch( - detectSegment, SEARCH_SIZE, detectModelIds); + detectSegment, SEARCH_SIZE, detectViewIds); hanlpMapResults.addAll(suffixHanlpMapResults); // remove entity name where search hanlpMapResults = hanlpMapResults.stream().filter(entry -> { @@ -88,7 +88,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy { } @Override - public void detectByStep(QueryContext queryContext, Set results, Set detectModelIds, + public void detectByStep(QueryContext queryContext, Set results, Set detectViewIds, Integer startIndex, Integer i, int offset) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java index d730f2043..2ccf64e4a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java @@ -97,12 +97,12 @@ public class SearchServiceImpl implements SearchService { List originals = knowledgeService.getTerms(queryText); log.info("hanlp parse result: {}", originals); MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); - Set detectModelIds = mapperHelper.getViewIds(queryReq.getModelId(), agentService.getAgent(agentId)); + Set detectViewIds = mapperHelper.getViewIds(queryReq.getModelId(), agentService.getAgent(agentId)); QueryContext queryContext = new QueryContext(); BeanUtils.copyProperties(queryReq, queryContext); Map> regTextMap = - searchMatchStrategy.match(queryContext, originals, detectModelIds); + searchMatchStrategy.match(queryContext, originals, detectViewIds); regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue())); // 4.get the most matching data diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java b/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java index 0cdf31708..955ce0b87 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/DataItem.java @@ -18,7 +18,7 @@ public class DataItem { private TypeEnums type; - private Long modelId; + private String modelId; private String defaultAgg; diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java index 1be07160e..f45dcf869 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java @@ -1,9 +1,9 @@ package com.tencent.supersonic.common.util.embedding; -import com.tencent.supersonic.common.pojo.enums.DictWordType; -import lombok.Data; - +import com.google.common.base.Objects; +import com.tencent.supersonic.common.pojo.Constants; import java.util.Map; +import lombok.Data; import org.apache.commons.lang3.StringUtils; @Data @@ -17,11 +17,30 @@ public class Retrieval { protected Map metadata; - public static Long getLongId(String id) { - if (StringUtils.isBlank(id)) { + public static Long getLongId(Object id) { + if (id == null || StringUtils.isBlank(id.toString())) { return null; } - String[] split = id.split(DictWordType.NATURE_SPILT); + String[] split = id.toString().split(Constants.UNDERLINE); return Long.parseLong(split[0]); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Retrieval retrieval = (Retrieval) o; + return Double.compare(retrieval.distance, distance) == 0 && Objects.equal(id, + retrieval.id) && Objects.equal(query, retrieval.query) + && Objects.equal(metadata, retrieval.metadata); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, distance, query, metadata); + } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java index 38847d341..d0a292c0e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java @@ -2,12 +2,15 @@ package com.tencent.supersonic.headless.server.listener; import com.alibaba.fastjson.JSONObject; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DataEvent; -import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -16,10 +19,6 @@ import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - @Component @Slf4j public class MetaEmbeddingListener implements ApplicationListener { @@ -38,12 +37,13 @@ public class MetaEmbeddingListener implements ApplicationListener { if (CollectionUtils.isEmpty(event.getDataItems())) { return; } + List embeddingQueries = event.getDataItems() .stream() .map(dataItem -> { EmbeddingQuery embeddingQuery = new EmbeddingQuery(); embeddingQuery.setQueryId( - dataItem.getId().toString() + DictWordType.NATURE_SPILT + dataItem.getId().toString() + Constants.UNDERLINE + dataItem.getType().name().toLowerCase()); embeddingQuery.setQuery(dataItem.getName()); Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/MetaEmbeddingService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/MetaEmbeddingService.java new file mode 100644 index 000000000..315f940a2 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/MetaEmbeddingService.java @@ -0,0 +1,11 @@ +package com.tencent.supersonic.headless.server.service; + +import com.tencent.supersonic.common.util.embedding.RetrieveQuery; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import java.util.List; + +public interface MetaEmbeddingService { + + List retrieveQuery(List viewIds, RetrieveQuery retrieveQuery, int num); + +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ViewService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ViewService.java index 7bf367c86..6b86a65fd 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ViewService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ViewService.java @@ -19,6 +19,8 @@ public interface ViewService { void delete(Long id, User user); + List getViewListByCache(MetaFilter metaFilter); + List getViews(User user); List getViewsInheritAuth(User user, Long domainId); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java index bc5288adf..78aaebcc0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java @@ -7,6 +7,7 @@ import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageInfo; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.DataItem; import com.tencent.supersonic.common.pojo.ModelRela; @@ -72,11 +73,11 @@ public class DimensionServiceImpl implements DimensionService { public DimensionServiceImpl(DimensionRepository dimensionRepository, - ModelService modelService, - ChatGptHelper chatGptHelper, - DatabaseService databaseService, - ModelRelaService modelRelaService, - ViewService viewService) { + ModelService modelService, + ChatGptHelper chatGptHelper, + DatabaseService databaseService, + ModelRelaService modelRelaService, + ViewService viewService) { this.modelService = modelService; this.dimensionRepository = dimensionRepository; this.chatGptHelper = chatGptHelper; @@ -129,8 +130,8 @@ public class DimensionServiceImpl implements DimensionService { DimensionConverter.convert(dimensionDO, dimensionReq); dimensionRepository.updateDimension(dimensionDO); if (!oldName.equals(dimensionDO.getName())) { - sendEvent(DataItem.builder().modelId(dimensionDO.getModelId()).newName(dimensionReq.getName()) - .name(oldName).type(TypeEnums.DIMENSION) + sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE) + .newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION) .id(dimensionDO.getId()).build(), EventType.UPDATE); } } @@ -264,7 +265,7 @@ public class DimensionServiceImpl implements DimensionService { } private List convertList(List dimensionDOS, - Map modelRespMap) { + Map modelRespMap) { List dimensionResps = Lists.newArrayList(); if (!CollectionUtils.isEmpty(dimensionDOS)) { dimensionResps = dimensionDOS.stream() @@ -364,9 +365,9 @@ public class DimensionServiceImpl implements DimensionService { } private void sendEventBatch(List dimensionDOS, EventType eventType) { - List dataItems = dimensionDOS.stream().map(dimensionDO -> - DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName()) - .modelId(dimensionDO.getModelId()).type(TypeEnums.DIMENSION).build()) + List dataItems = dimensionDOS.stream() + .map(dimensionDO -> DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName()) + .modelId(dimensionDO.getModelId() + Constants.UNDERLINE).type(TypeEnums.DIMENSION).build()) .collect(Collectors.toList()); eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); } @@ -376,10 +377,4 @@ public class DimensionServiceImpl implements DimensionService { Lists.newArrayList(dataItem), eventType)); } - private DataItem getDataItem(DimensionDO dimensionDO) { - return DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName()) - .bizName(dimensionDO.getBizName()) - .modelId(dimensionDO.getModelId()).type(TypeEnums.DIMENSION).build(); - } - } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetaEmbeddingServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetaEmbeddingServiceImpl.java new file mode 100644 index 000000000..ba0052d67 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetaEmbeddingServiceImpl.java @@ -0,0 +1,114 @@ +package com.tencent.supersonic.headless.server.service.impl; + +import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.enums.StatusEnum; +import com.tencent.supersonic.common.util.ComponentFactory; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.common.util.embedding.RetrieveQuery; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; +import com.tencent.supersonic.headless.api.pojo.response.ViewResp; +import com.tencent.supersonic.headless.server.pojo.MetaFilter; +import com.tencent.supersonic.headless.server.service.MetaEmbeddingService; +import com.tencent.supersonic.headless.server.service.ViewService; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +@Service +@Slf4j +public class MetaEmbeddingServiceImpl implements MetaEmbeddingService { + + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); + @Autowired + private EmbeddingConfig embeddingConfig; + + @Autowired + private ViewService viewService; + + @Override + public List retrieveQuery(List viewIds, RetrieveQuery retrieveQuery, int num) { + // viewIds->modelIds + MetaFilter metaFilter = new MetaFilter(); + metaFilter.setStatus(StatusEnum.ONLINE.getCode()); + metaFilter.setIds(viewIds); + List viewListByCache = viewService.getViewListByCache(metaFilter); + Set allModels = getModels(viewListByCache); + + Map> modelIdToViewIds = viewListByCache.stream() + .flatMap(viewResp -> viewResp.getAllModels().stream() + .map(modelId -> Pair.of(modelId, viewResp.getId()))) + .collect(Collectors.groupingBy(Pair::getLeft, Collectors.mapping(Pair::getRight, Collectors.toList()))); + + if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) { + Map filterCondition = new HashMap<>(); + filterCondition.put("modelId", allModels.stream().findFirst().get().toString()); + retrieveQuery.setFilterCondition(filterCondition); + } + + String collectionName = embeddingConfig.getMetaCollectionName(); + List resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery, num); + if (CollectionUtils.isEmpty(resultList)) { + return new ArrayList<>(); + } + //filter by modelId + if (CollectionUtils.isEmpty(allModels)) { + return resultList; + } + return resultList.stream() + .map(retrieveQueryResult -> { + List retrievals = retrieveQueryResult.getRetrieval(); + if (CollectionUtils.isEmpty(retrievals)) { + return retrieveQueryResult; + } + //filter by modelId + retrievals.removeIf(retrieval -> { + Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId")); + if (Objects.isNull(modelId)) { + return CollectionUtils.isEmpty(allModels); + } + return !allModels.contains(modelId); + }); + //add viewId + retrievals = retrievals.stream().flatMap(retrieval -> { + Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId")); + List viewIdsByModelId = modelIdToViewIds.get(modelId); + if (!CollectionUtils.isEmpty(viewIdsByModelId)) { + Set result = new HashSet<>(); + for (Long viewId : viewIdsByModelId) { + Retrieval retrievalNew = new Retrieval(); + BeanUtils.copyProperties(retrieval, retrievalNew); + retrievalNew.getMetadata().putIfAbsent("viewId", viewId + Constants.UNDERLINE); + result.add(retrievalNew); + } + return result.stream(); + } + Set result = new HashSet<>(); + result.add(retrieval); + return result.stream(); + }).collect(Collectors.toList()); + retrieveQueryResult.setRetrieval(retrievals); + return retrieveQueryResult; + }) + .filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval())) + .collect(Collectors.toList()); + } + + private Set getModels(List viewListByCache) { + return viewListByCache.stream() + .flatMap(viewResp -> viewResp.getAllModels().stream()) + .collect(Collectors.toSet()); + } +} \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java index 794f13fc3..e08383a93 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java @@ -6,6 +6,7 @@ import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageInfo; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.DataItem; import com.tencent.supersonic.common.pojo.enums.AuthType; @@ -259,8 +260,8 @@ public class MetricServiceImpl implements MetricService { metricFilter.setModelIds(Lists.newArrayList(modelId)); List metricResps = getMetrics(metricFilter); return metricResps.stream().filter(metricResp -> - MetricDefineType.FIELD.equals(metricResp.getMetricDefineType()) - || MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) + MetricDefineType.FIELD.equals(metricResp.getMetricDefineType()) + || MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) .collect(Collectors.toList()); } @@ -450,8 +451,8 @@ public class MetricServiceImpl implements MetricService { new HashMap<>(), Lists.newArrayList()); return DataItem.builder().id(metricDO.getId()).name(metricDO.getName()) .bizName(metricDO.getBizName()) - .modelId(metricDO.getModelId()).type(TypeEnums.METRIC) - .defaultAgg(metricResp.getDefaultAgg()).build(); + .modelId(metricDO.getModelId() + Constants.UNDERLINE) + .type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build(); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ViewServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ViewServiceImpl.java index b1adacfa6..5b1e6705e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ViewServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ViewServiceImpl.java @@ -3,6 +3,8 @@ package com.tencent.supersonic.headless.server.service.impl; import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.AuthType; @@ -19,23 +21,26 @@ import com.tencent.supersonic.headless.server.persistence.mapper.ViewDOMapper; import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.ViewService; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; - import java.util.Arrays; import java.util.Comparator; import java.util.Date; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; @Service public class ViewServiceImpl extends ServiceImpl implements ViewService { + protected final Cache> viewSchemaCache = + CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.SECONDS).build(); + @Autowired private DomainService domainService; @@ -153,4 +158,14 @@ public class ViewServiceImpl return admins.contains(userName) || viewResp.getCreatedBy().equals(userName); } + @Override + public List getViewListByCache(MetaFilter metaFilter) { + List viewList = viewSchemaCache.getIfPresent(metaFilter); + if (CollectionUtils.isEmpty(viewList)) { + viewList = getViewList(metaFilter); + viewSchemaCache.put(metaFilter, viewList); + } + return viewList; + } + }