diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java index 5715979a7..ab37a01ad 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java @@ -42,4 +42,18 @@ public class OptimizationConfig { @Value("${user.s2ql.switch:false}") private boolean useS2qlSwitch; + @Value("${embedding.mapper.word.min:2}") + private int embeddingMapperWordMin; + + @Value("${embedding.mapper.number:10}") + private int embeddingMapperNumber; + + @Value("${embedding.mapper.round.number:10}") + private int embeddingMapperRoundNumber; + + @Value("${embedding.mapper.sum.number:10}") + private int embeddingMapperSumNumber; + + @Value("${embedding.mapper.distance.threshold:0.3}") + private Double embeddingMapperDistanceThreshold; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java index f87bc52c3..48440a70a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java @@ -1,13 +1,26 @@ package com.tencent.supersonic.chat.mapper; +import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.chat.api.component.SchemaMapper; +import com.tencent.supersonic.chat.api.pojo.ModelSchema; import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.chat.service.SemanticService; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.knowledge.utils.NatureHelper; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; /** * base Mapper @@ -19,12 +32,17 @@ public abstract class BaseMapper implements SchemaMapper { public void map(QueryContext queryContext) { String simpleName = this.getClass().getSimpleName(); + long startTime = System.currentTimeMillis(); + log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo()); - log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo()); + try { + work(queryContext); + } catch (Exception e) { + log.error("work error", e); + } - work(queryContext); - - log.debug("after {},mapInfo:{}", simpleName, queryContext.getMapInfo()); + long cost = System.currentTimeMillis() - startTime; + log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo()); } public abstract void work(QueryContext queryContext); @@ -38,4 +56,63 @@ public abstract class BaseMapper implements SchemaMapper { } schemaElementMatches.add(schemaElementMatch); } + + public Set getModelIds(QueryContext queryContext) { + return ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest()); + } + + public List filterByModelIds(List terms, Set detectModelIds) { + logTerms(terms); + if 
(CollectionUtils.isNotEmpty(detectModelIds)) {
+            terms = terms.stream().filter(term -> {
+                Long modelId = NatureHelper.getModelId(term.getNature().toString());
+                if (Objects.nonNull(modelId)) {
+                    return detectModelIds.contains(modelId);
+                }
+                return false;
+            }).collect(Collectors.toList());
+            log.info("terms filter by modelIds:{}", detectModelIds);
+            logTerms(terms);
+        }
+        return terms;
+    }
+
+    public void logTerms(List<Term> terms) {
+        if (CollectionUtils.isEmpty(terms)) {
+            return;
+        }
+        for (Term term : terms) {
+            log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
+        }
+    }
+
+    public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
+        SchemaElement element = new SchemaElement();
+        SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
+        ModelSchema modelSchema = schemaService.getModelSchema(modelId);
+        if (Objects.isNull(modelSchema)) {
+            return null;
+        }
+        SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
+        if (Objects.isNull(elementDb)) {
+            log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
+            return null;
+        }
+        BeanUtils.copyProperties(elementDb, element);
+        element.setAlias(getAlias(elementDb));
+        return element;
+    }
+
+    public List<String> getAlias(SchemaElement element) {
+        if (!SchemaElementType.VALUE.equals(element.getType())) {
+            return element.getAlias();
+        }
+        if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
+                element.getName())) {
+            return element.getAlias().stream()
+                    .filter(aliasItem -> aliasItem.contains(element.getName()))
+                    .collect(Collectors.toList());
+        }
+        return element.getAlias();
+    }
 }
diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java
new file mode 100644
index 000000000..e2265c01b
--- /dev/null
+++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java
@@ -0,0 +1,121 @@
+package com.tencent.supersonic.chat.mapper;
+
+import com.hankcs.hanlp.seg.common.Term;
+import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+
+/**
+ * match strategy implement
+ */
+@Service
+@Slf4j
+public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
+
+    @Autowired
+    private MapperHelper mapperHelper;
+
+
+    @Override
+    public Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
+        String text = queryReq.getQueryText();
+        if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
+            return null;
+        }
+
+        log.debug("terms:{},detectModelIds:{}", terms, detectModelIds);
+
+        List<T> detects = detect(queryReq, terms, detectModelIds);
+        Map<MatchText, List<T>> result = new HashMap<>();
+
+        result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
+        return result;
+    }
+
+    public List<T> detect(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
+        Map<Integer, Integer> regOffsetToLength = 
getRegOffsetToLength(terms); + String text = queryReq.getQueryText(); + Set results = new HashSet<>(); + + for (Integer index = 0; index <= text.length() - 1; ) { + + for (Integer i = index; i <= text.length(); ) { + int offset = mapperHelper.getStepOffset(terms, index); + i = mapperHelper.getStepIndex(regOffsetToLength, i); + if (i <= text.length()) { + detectByStep(queryReq, results, detectModelIds, index, i, offset); + } + } + index = mapperHelper.getStepIndex(regOffsetToLength, index); + } + return new ArrayList<>(results); + } + + public Map getRegOffsetToLength(List terms) { + return terms.stream().sorted(Comparator.comparing(Term::length)) + .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2)); + } + + public void selectResultInOneRound(Set existResults, List oneRoundResults) { + if (CollectionUtils.isEmpty(oneRoundResults)) { + return; + } + for (T oneRoundResult : oneRoundResults) { + if (existResults.contains(oneRoundResult)) { + boolean isDeleted = existResults.removeIf( + existResult -> { + boolean delete = needDelete(oneRoundResult, existResult); + if (delete) { + log.info("deleted existResult:{}", existResult); + } + return delete; + } + ); + if (isDeleted) { + log.info("deleted, add oneRoundResult:{}", oneRoundResult); + existResults.add(oneRoundResult); + } + } else { + existResults.add(oneRoundResult); + } + } + } + + public List getMatches(QueryReq queryReq, List terms, Set detectModelIds) { + Map> matchResult = match(queryReq, terms, detectModelIds); + List matches = new ArrayList<>(); + if (Objects.isNull(matchResult)) { + return matches; + } + Optional> first = matchResult.entrySet().stream() + .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) + .map(entry -> entry.getValue()).findFirst(); + + if (first.isPresent()) { + matches = first.get(); + } + return matches; + } + + public abstract boolean needDelete(T oneRoundResult, T existResult); + + public abstract String getMapKey(T a); + + public abstract void detectByStep(QueryReq queryReq, Set results, Set detectModelIds, Integer startIndex, + Integer index, int offset); + + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java index 28866424f..3d1d02458 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java @@ -1,7 +1,19 @@ package com.tencent.supersonic.chat.mapper; +import com.alibaba.fastjson.JSONObject; +import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult; +import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder; +import com.tencent.supersonic.knowledge.utils.HanlpHelper; +import java.util.List; +import java.util.Set; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; /*** * A mapper that is capable of semantic understanding of text. @@ -13,10 +25,44 @@ public class EmbeddingMapper extends BaseMapper { public void work(QueryContext queryContext) { //1. 
query from embedding by queryText + String queryText = queryContext.getRequest().getQueryText(); + List terms = HanlpHelper.getTerms(queryText); + + EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); + + Set detectModelIds = getModelIds(queryContext); + + terms = filterByModelIds(terms, detectModelIds); + + List matchResults = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds); + + HanlpHelper.transLetterOriginal(matchResults); //2. build SchemaElementMatch by info + MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); + for (EmbeddingResult matchResult : matchResults) { + Long elementId = Retrieval.getLongId(matchResult.getId()); - //3. add to mapInfo + SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()), + SchemaElement.class); + String modelIdStr = matchResult.getMetadata().get("modelId"); + if (StringUtils.isBlank(modelIdStr)) { + continue; + } + long modelId = Long.parseLong(modelIdStr); + + schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId); + + SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() + .element(schemaElement) + .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) + .word(matchResult.getName()) + .similarity(mapperHelper.getSimilarity(matchResult.getName(), matchResult.getDetectWord())) + .detectWord(matchResult.getDetectWord()) + .build(); + //3. add to mapInfo + addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch); + } } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java new file mode 100644 index 000000000..280beef89 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java @@ -0,0 +1,116 @@ +package com.tencent.supersonic.chat.mapper; + +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.config.OptimizationConfig; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.common.util.embedding.RetrieveQuery; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult; +import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +/** + * match strategy implement + */ +@Service +@Slf4j +public class EmbeddingMatchStrategy extends BaseMatchStrategy { + + @Autowired + private OptimizationConfig optimizationConfig; + @Autowired + private EmbeddingUtils embeddingUtils; + + @Override + public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) { + return getMapKey(oneRoundResult).equals(getMapKey(existResult)) + && existResult.getDistance() > oneRoundResult.getDistance(); + } + + @Override 
+    public String getMapKey(EmbeddingResult a) {
+        return a.getName() + Constants.UNDERLINE + a.getId();
+    }
+
+    public void detectByStep(QueryReq queryReq, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
+            Integer startIndex, Integer index, int offset) {
+        String detectSegment = queryReq.getQueryText().substring(startIndex, index);
+        // step1. build query params
+        if (StringUtils.isBlank(detectSegment)
+                || detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMin()) {
+            return;
+        }
+        int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
+        Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
+        Map<String, String> filterCondition = null;
+
+        // if only one modelId, add to filterCondition
+        if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
+            filterCondition = new HashMap<>();
+            filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
+        }
+
+        RetrieveQuery retrieveQuery = RetrieveQuery.builder()
+                .queryTextsList(Collections.singletonList(detectSegment))
+                .filterCondition(filterCondition)
+                .queryEmbeddings(null)
+                .build();
+        // step2. retrieveQuery by detectSegment
+        List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
+                MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
+
+        if (CollectionUtils.isEmpty(retrieveQueryResults)) {
+            return;
+        }
+        // step3. build EmbeddingResults. filter by modelId
+        List<EmbeddingResult> collect = retrieveQueryResults.stream()
+                .map(retrieveQueryResult -> {
+                    List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
+                    if (CollectionUtils.isNotEmpty(retrievals)) {
+                        retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
+                        if (CollectionUtils.isNotEmpty(detectModelIds)) {
+                            retrievals.removeIf(retrieval -> {
+                                String modelIdStr = retrieval.getMetadata().get("modelId");
+                                if (StringUtils.isBlank(modelIdStr)) {
+                                    return true;
+                                }
+                                return !detectModelIds.contains(Long.parseLong(modelIdStr));
+                            });
+                        }
+                    }
+                    return retrieveQueryResult;
+                })
+                .filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
+                .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
+                        .map(retrieval -> {
+                            EmbeddingResult embeddingResult = new EmbeddingResult();
+                            BeanUtils.copyProperties(retrieval, embeddingResult);
+                            embeddingResult.setDetectWord(detectSegment);
+                            embeddingResult.setName(retrieval.getQuery());
+                            return embeddingResult;
+                        }))
+                .collect(Collectors.toList());
+
+        // step4. 
select mapResul in one round + int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber(); + List oneRoundResults = collect.stream() + .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance)) + .limit(roundNumber) + .collect(Collectors.toList()); + selectResultInOneRound(existResults, oneRoundResults); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java index 63d4fd1d2..554518a84 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java @@ -1,28 +1,22 @@ package com.tencent.supersonic.chat.mapper; import com.hankcs.hanlp.seg.common.Term; -import com.tencent.supersonic.chat.api.pojo.ModelSchema; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; -import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.knowledge.dictionary.MapResult; +import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult; import com.tencent.supersonic.knowledge.utils.HanlpHelper; import com.tencent.supersonic.knowledge.utils.NatureHelper; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.BeanUtils; /*** * A mapper capable of prefix and suffix similarity parsing for @@ -37,45 +31,23 @@ public class HanlpDictMapper extends BaseMapper { String queryText = queryContext.getRequest().getQueryText(); List terms = HanlpHelper.getTerms(queryText); - QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class); + HanlpMatchStrategy matchStrategy = ContextUtils.getBean(HanlpMatchStrategy.class); - Set detectModelIds = ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest()); + Set detectModelIds = getModelIds(queryContext); terms = filterByModelIds(terms, detectModelIds); - Map> matchResult = matchStrategy.match(queryContext.getRequest(), terms, - detectModelIds); - - List matches = getMatches(matchResult); + List matches = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds); HanlpHelper.transLetterOriginal(matches); convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms); } - private List filterByModelIds(List terms, Set detectModelIds) { - for (Term term : terms) { - log.info("before word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency()); - } - if (CollectionUtils.isNotEmpty(detectModelIds)) { - terms = terms.stream().filter(term -> { - Long modelId = NatureHelper.getModelId(term.getNature().toString()); - if (Objects.nonNull(modelId)) { - return detectModelIds.contains(modelId); - } - return false; - }).collect(Collectors.toList()); - } - for (Term term : terms) { - log.info("after filter word:{},nature:{},frequency:{}", term.word, term.nature.toString(), - term.getFrequency()); - } - return terms; - } - - 
private void convertTermsToSchemaMapInfo(List mapResults, SchemaMapInfo schemaMap, List terms) { - if (CollectionUtils.isEmpty(mapResults)) { + private void convertTermsToSchemaMapInfo(List hanlpMapResults, SchemaMapInfo schemaMap, + List terms) { + if (CollectionUtils.isEmpty(hanlpMapResults)) { return; } @@ -83,8 +55,8 @@ public class HanlpDictMapper extends BaseMapper { Collectors.toMap(entry -> entry.getWord() + entry.getNature(), term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2)); - for (MapResult mapResult : mapResults) { - for (String nature : mapResult.getNatures()) { + for (HanlpMapResult hanlpMapResult : hanlpMapResults) { + for (String nature : hanlpMapResult.getNatures()) { Long modelId = NatureHelper.getModelId(nature); if (Objects.isNull(modelId)) { continue; @@ -93,33 +65,21 @@ public class HanlpDictMapper extends BaseMapper { if (Objects.isNull(elementType)) { continue; } - - SemanticService schemaService = ContextUtils.getBean(SemanticService.class); - ModelSchema modelSchema = schemaService.getModelSchema(modelId); - if (Objects.isNull(modelSchema)) { - return; - } - Long elementID = NatureHelper.getElementID(nature); - Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature); - - SchemaElement elementDb = modelSchema.getElement(elementType, elementID); - if (Objects.isNull(elementDb)) { - log.info("element is null, elementType:{},elementID:{}", elementType, elementID); + SchemaElement element = getSchemaElement(modelId, elementType, elementID); + if (element == null) { continue; } - SchemaElement element = new SchemaElement(); - BeanUtils.copyProperties(elementDb, element); - element.setAlias(getAlias(elementDb)); if (element.getType().equals(SchemaElementType.VALUE)) { - element.setName(mapResult.getName()); + element.setName(hanlpMapResult.getName()); } + Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature); SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() .element(element) .frequency(frequency) - .word(mapResult.getName()) - .similarity(mapResult.getSimilarity()) - .detectWord(mapResult.getDetectWord()) + .word(hanlpMapResult.getName()) + .similarity(hanlpMapResult.getSimilarity()) + .detectWord(hanlpMapResult.getDetectWord()) .build(); addToSchemaMap(schemaMap, modelId, schemaElementMatch); @@ -127,30 +87,5 @@ public class HanlpDictMapper extends BaseMapper { } } - private List getMatches(Map> matchResult) { - List matches = new ArrayList<>(); - if (Objects.isNull(matchResult)) { - return matches; - } - Optional> first = matchResult.entrySet().stream() - .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) - .map(entry -> entry.getValue()).findFirst(); - if (first.isPresent()) { - matches = first.get(); - } - return matches; - } - - public List getAlias(SchemaElement element) { - if (!SchemaElementType.VALUE.equals(element.getType())) { - return element.getAlias(); - } - if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) { - return element.getAlias().stream() - .filter(aliasItem -> aliasItem.contains(element.getName())) - .collect(Collectors.toList()); - } - return element.getAlias(); - } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpMatchStrategy.java new file mode 100644 index 000000000..96657e4e7 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpMatchStrategy.java @@ -0,0 +1,119 
@@ +package com.tencent.supersonic.chat.mapper; + +import com.hankcs.hanlp.seg.common.Term; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.config.OptimizationConfig; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult; +import com.tencent.supersonic.knowledge.service.SearchService; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +/** + * match strategy implement + */ +@Service +@Slf4j +public class HanlpMatchStrategy extends BaseMatchStrategy { + + @Autowired + private MapperHelper mapperHelper; + + @Autowired + private OptimizationConfig optimizationConfig; + + @Override + public Map> match(QueryReq queryReq, List terms, Set detectModelIds) { + String text = queryReq.getQueryText(); + if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { + return null; + } + + log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds); + + List detects = detect(queryReq, terms, detectModelIds); + Map> result = new HashMap<>(); + + result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); + return result; + } + + @Override + public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) { + return getMapKey(oneRoundResult).equals(getMapKey(existResult)) + && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length(); + } + + public void detectByStep(QueryReq queryReq, Set existResults, Set detectModelIds, + Integer startIndex, Integer index, int offset) { + String text = queryReq.getQueryText(); + Integer agentId = queryReq.getAgentId(); + String detectSegment = text.substring(startIndex, index); + + // step1. pre search + Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); + LinkedHashSet hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, + agentId, + detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); + // step2. suffix search + LinkedHashSet suffixHanlpMapResults = SearchService.suffixSearch(detectSegment, + oneDetectionMaxSize, + agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); + + hanlpMapResults.addAll(suffixHanlpMapResults); + + if (CollectionUtils.isEmpty(hanlpMapResults)) { + return; + } + // step3. merge pre/suffix result + hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length())) + .collect(Collectors.toCollection(LinkedHashSet::new)); + + // step4. 
filter by similarity
+        hanlpMapResults = hanlpMapResults.stream()
+                .filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
+                        >= mapperHelper.getThresholdMatch(term.getNatures()))
+                .filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
+                .collect(Collectors.toCollection(LinkedHashSet::new));
+
+        log.info("after isSimilarity parseResults:{}", hanlpMapResults);
+
+        hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
+            parseResult.setOffset(offset);
+            parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
+            return parseResult;
+        }).collect(Collectors.toCollection(LinkedHashSet::new));
+
+        // step5. take only one dimension or 10 metric/dimension values per round.
+        List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
+                .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
+                .collect(Collectors.toList())
+                .stream()
+                .limit(1)
+                .collect(Collectors.toList());
+
+        Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
+        List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
+                .collect(Collectors.toList());
+
+        if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
+            oneRoundResults = dimensionMetrics;
+        }
+        // step6. select mapResult in one round
+        selectResultInOneRound(existResults, oneRoundResults);
+    }
+
+    public String getMapKey(HanlpMapResult a) {
+        return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
+    }
+}
diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java
index 53bdc704b..48056b733 100644
--- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java
+++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java
@@ -1,11 +1,13 @@
 package com.tencent.supersonic.chat.mapper;
 
 import com.hankcs.hanlp.algorithm.EditDistance;
+import com.hankcs.hanlp.seg.common.Term;
 import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
 import com.tencent.supersonic.chat.config.OptimizationConfig;
 import com.tencent.supersonic.chat.service.AgentService;
 import com.tencent.supersonic.common.util.ContextUtils;
 import com.tencent.supersonic.knowledge.utils.NatureHelper;
+import java.util.Comparator;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -39,10 +41,14 @@ public class MapperHelper {
         return index;
     }
 
-    public Integer getStepOffset(List<Integer> termList, Integer index) {
+
+    public Integer getStepOffset(List<Term> termList, Integer index) {
+        List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
+                .map(term -> term.getOffset()).collect(Collectors.toList());
+
         for (int j = 0; j < termList.size() - 1; j++) {
-            if (termList.get(j) <= index && termList.get(j + 1) > index) {
-                return termList.get(j);
+            if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
+                return offsetList.get(j);
             }
         }
         return index;
diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java
index 481e9c128..db25d5f44 100644
--- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java
+++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.mapper;
 
 import com.hankcs.hanlp.seg.common.Term;
 import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
-import 
com.tencent.supersonic.knowledge.dictionary.MapResult; import java.util.List; import java.util.Map; import java.util.Set; @@ -10,8 +9,8 @@ import java.util.Set; /** * match strategy */ -public interface MatchStrategy { +public interface MatchStrategy { - Map> match(QueryReq queryReq, List terms, Set detectModelId); + Map> match(QueryReq queryReq, List terms, Set detectModelId); } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/QueryMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/QueryMatchStrategy.java deleted file mode 100644 index 2a89e98dc..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/QueryMatchStrategy.java +++ /dev/null @@ -1,163 +0,0 @@ -package com.tencent.supersonic.chat.mapper; - -import com.hankcs.hanlp.seg.common.Term; -import com.tencent.supersonic.chat.api.pojo.request.QueryReq; -import com.tencent.supersonic.chat.config.OptimizationConfig; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.knowledge.dictionary.MapResult; -import com.tencent.supersonic.knowledge.service.SearchService; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.compress.utils.Lists; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; - -/** - * match strategy implement - */ -@Service -@Slf4j -public class QueryMatchStrategy implements MatchStrategy { - - @Autowired - private MapperHelper mapperHelper; - - @Autowired - private OptimizationConfig optimizationConfig; - - @Override - public Map> match(QueryReq queryReq, List terms, Set detectModelIds) { - String text = queryReq.getQueryText(); - if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { - return null; - } - - Map regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length)) - .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2)); - - List offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset)) - .map(term -> term.getOffset()).collect(Collectors.toList()); - - log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms, - regOffsetToLength, offsetList, detectModelIds); - - List detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds); - Map> result = new HashMap<>(); - - result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); - return result; - } - - private List detect(QueryReq queryReq, Map regOffsetToLength, List offsetList, - Set detectModelIds) { - String text = queryReq.getQueryText(); - List results = Lists.newArrayList(); - - for (Integer index = 0; index <= text.length() - 1; ) { - - Set mapResultRowSet = new LinkedHashSet(); - - for (Integer i = index; i <= text.length(); ) { - int offset = mapperHelper.getStepOffset(offsetList, index); - i = mapperHelper.getStepIndex(regOffsetToLength, i); - if (i <= text.length()) { - List mapResults = detectByStep(queryReq, detectModelIds, index, i, offset); - selectMapResultInOneRound(mapResultRowSet, mapResults); - } - } - index = mapperHelper.getStepIndex(regOffsetToLength, 
index); - results.addAll(mapResultRowSet); - } - return results; - } - - private void selectMapResultInOneRound(Set mapResultRowSet, List mapResults) { - for (MapResult mapResult : mapResults) { - if (mapResultRowSet.contains(mapResult)) { - boolean isDeleted = mapResultRowSet.removeIf( - entry -> { - boolean deleted = getMapKey(mapResult).equals(getMapKey(entry)) - && entry.getDetectWord().length() < mapResult.getDetectWord().length(); - if (deleted) { - log.info("deleted entry:{}", entry); - } - return deleted; - } - ); - if (isDeleted) { - log.info("deleted, add mapResult:{}", mapResult); - mapResultRowSet.add(mapResult); - } - } else { - mapResultRowSet.add(mapResult); - } - } - } - - private String getMapKey(MapResult a) { - return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures()); - } - - private List detectByStep(QueryReq queryReq, Set detectModelIds, Integer index, Integer i, - int offset) { - String text = queryReq.getQueryText(); - Integer agentId = queryReq.getAgentId(); - String detectSegment = text.substring(index, i); - - // step1. pre search - Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); - LinkedHashSet mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId, - detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); - // step2. suffix search - LinkedHashSet suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize, - agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); - - mapResults.addAll(suffixMapResults); - - if (CollectionUtils.isEmpty(mapResults)) { - return new ArrayList<>(); - } - // step3. merge pre/suffix result - mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length())) - .collect(Collectors.toCollection(LinkedHashSet::new)); - - // step4. filter by similarity - mapResults = mapResults.stream() - .filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName()) - >= mapperHelper.getThresholdMatch(term.getNatures())) - .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) - .collect(Collectors.toCollection(LinkedHashSet::new)); - - log.info("after isSimilarity parseResults:{}", mapResults); - - mapResults = mapResults.stream().map(parseResult -> { - parseResult.setOffset(offset); - parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName())); - return parseResult; - }).collect(Collectors.toCollection(LinkedHashSet::new)); - - // step5. take only one dimension or 10 metric/dimension value per rond. 
- List dimensionMetrics = mapResults.stream() - .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures())) - .collect(Collectors.toList()) - .stream() - .limit(1) - .collect(Collectors.toList()); - - if (CollectionUtils.isNotEmpty(dimensionMetrics)) { - return dimensionMetrics; - } else { - return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList()); - } - } -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java index 5a0fdfba0..96192c723 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java @@ -4,7 +4,7 @@ import com.google.common.collect.Lists; import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.common.pojo.enums.DictWordType; -import com.tencent.supersonic.knowledge.dictionary.MapResult; +import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult; import com.tencent.supersonic.knowledge.service.SearchService; import java.util.List; import java.util.Map; @@ -20,17 +20,15 @@ import org.springframework.stereotype.Service; * match strategy implement */ @Service -public class SearchMatchStrategy implements MatchStrategy { +public class SearchMatchStrategy extends BaseMatchStrategy { private static final int SEARCH_SIZE = 3; @Override - public Map> match(QueryReq queryReq, List originals, Set detectModelIds) { + public Map> match(QueryReq queryReq, List originals, + Set detectModelIds) { String text = queryReq.getQueryText(); - Map regOffsetToLength = originals.stream() - .filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT)) - .collect(Collectors.toMap(Term::getOffset, value -> value.word.length(), - (value1, value2) -> value2)); + Map regOffsetToLength = getRegOffsetToLength(originals); List detectIndexList = Lists.newArrayList(); @@ -46,19 +44,19 @@ public class SearchMatchStrategy implements MatchStrategy { index++; } } - Map> regTextMap = new ConcurrentHashMap<>(); + Map> regTextMap = new ConcurrentHashMap<>(); detectIndexList.stream().parallel().forEach(detectIndex -> { String regText = text.substring(0, detectIndex); String detectSegment = text.substring(detectIndex); if (StringUtils.isNotEmpty(detectSegment)) { - List mapResults = SearchService.prefixSearch(detectSegment, + List hanlpMapResults = SearchService.prefixSearch(detectSegment, SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds); - List suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE, - queryReq.getAgentId(), detectModelIds); - mapResults.addAll(suffixMapResults); + List suffixHanlpMapResults = SearchService.suffixSearch( + detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds); + hanlpMapResults.addAll(suffixHanlpMapResults); // remove entity name where search - mapResults = mapResults.stream().filter(entry -> { + hanlpMapResults = hanlpMapResults.stream().filter(entry -> { List natures = entry.getNatures().stream() .filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType())) .collect(Collectors.toList()); @@ -71,10 +69,27 @@ public class SearchMatchStrategy implements MatchStrategy { .regText(regText) .detectSegment(detectSegment) .build(); - regTextMap.put(matchText, mapResults); + regTextMap.put(matchText, hanlpMapResults); } } ); return 
regTextMap; } + + @Override + public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) { + return false; + } + + @Override + public String getMapKey(HanlpMapResult a) { + return null; + } + + @Override + public void detectByStep(QueryReq queryReq, Set results, Set detectModelIds, + Integer startIndex, + Integer i, int offset) { + + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java index e56642982..aa9b2249a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/SimilarMetricExecuteResponder.java @@ -57,7 +57,7 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder { metric.setOrder(metricOrder++); } for (Retrieval retrieval : retrievals) { - if (!metricIds.contains(retrieval.getId())) { + if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) { SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()), SchemaElement.class); if (retrieval.getMetadata().containsKey("modelId")) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index eeebfd760..98e54e1b2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -47,7 +47,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.knowledge.dictionary.MapResult; +import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult; import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary; import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.knowledge.utils.HanlpHelper; @@ -628,10 +628,10 @@ public class QueryServiceImpl implements QueryService { return terms.stream().map(term -> term.getWord()).collect(Collectors.toList()); } //search from prefixSearch - List mapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(), + List hanlpMapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(), 2000, dimensionValueReq.getAgentId(), detectModelIds); - HanlpHelper.transLetterOriginal(mapResultList); - return mapResultList.stream() + HanlpHelper.transLetterOriginal(hanlpMapResultList); + return hanlpMapResultList.stream() .filter(o -> { for (String nature : o.getNatures()) { Long elementID = NatureHelper.getElementID(nature); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java index eff21a725..c3d79f78b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java @@ -24,7 +24,7 @@ import com.tencent.supersonic.chat.service.ChatService; import 
com.tencent.supersonic.chat.service.SearchService; import com.tencent.supersonic.knowledge.utils.NatureHelper; import com.tencent.supersonic.knowledge.dictionary.DictWord; -import com.tencent.supersonic.knowledge.dictionary.MapResult; +import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.utils.HanlpHelper; @@ -94,11 +94,12 @@ public class SearchServiceImpl implements SearchService { MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); Set detectModelIds = mapperHelper.getModelIds(queryReq); - Map> regTextMap = searchMatchStrategy.match(queryReq, originals, detectModelIds); + Map> regTextMap = + searchMatchStrategy.match(queryReq, originals, detectModelIds); regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue())); // 4.get the most matching data - Optional>> mostSimilarSearchResult = regTextMap.entrySet() + Optional>> mostSimilarSearchResult = regTextMap.entrySet() .stream() .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) .reduce((entry1, entry2) -> @@ -109,7 +110,7 @@ public class SearchServiceImpl implements SearchService { if (!mostSimilarSearchResult.isPresent()) { return Lists.newArrayList(); } - Map.Entry> searchTextEntry = mostSimilarSearchResult.get(); + Map.Entry> searchTextEntry = mostSimilarSearchResult.get(); log.info("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq); Set searchResults = new LinkedHashSet(); @@ -275,9 +276,9 @@ public class SearchServiceImpl implements SearchService { * @param recommendTextListEntry * @return */ - private Map getNatureToNameMap(Map.Entry> recommendTextListEntry, + private Map getNatureToNameMap(Map.Entry> recommendTextListEntry, Set possibleModels) { - List recommendValues = recommendTextListEntry.getValue(); + List recommendValues = recommendTextListEntry.getValue(); return recommendValues.stream() .flatMap(entry -> entry.getNatures().stream() .filter(nature -> { @@ -288,26 +289,25 @@ public class SearchServiceImpl implements SearchService { return possibleModels.contains(model); }) .map(nature -> { - DictWord posDO = new DictWord(); - posDO.setWord(entry.getName()); - posDO.setNature(nature); - return posDO; - } - )).sorted(Comparator.comparingInt(a -> a.getWord().length())) + DictWord posDO = new DictWord(); + posDO.setWord(entry.getName()); + posDO.setNature(nature); + return posDO; + })).sorted(Comparator.comparingInt(a -> a.getWord().length())) .collect(Collectors.toMap(DictWord::getNature, DictWord::getWord, (value1, value2) -> value1, LinkedHashMap::new)); } private boolean searchMetricAndDimension(Set possibleModels, Map modelToName, - Map.Entry> searchTextEntry, Set searchResults) { + Map.Entry> searchTextEntry, Set searchResults) { boolean existMetric = false; log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry); MatchText matchText = searchTextEntry.getKey(); - List mapResults = searchTextEntry.getValue(); + List hanlpMapResults = searchTextEntry.getValue(); - for (MapResult mapResult : mapResults) { + for (HanlpMapResult hanlpMapResult : hanlpMapResults) { - List dimensionMetricClassIds = mapResult.getNatures().stream() + List dimensionMetricClassIds = hanlpMapResult.getNatures().stream() .map(nature -> new ModelWithSemanticType(NatureHelper.getModelId(nature), NatureHelper.convertToElementType(nature))) .filter(entry -> matchCondition(entry, 
possibleModels)).collect(Collectors.toList()); @@ -322,8 +322,8 @@ public class SearchServiceImpl implements SearchService { SearchResult searchResult = SearchResult.builder() .modelId(modelId) .modelName(modelToName.get(modelId)) - .recommend(matchText.getRegText() + mapResult.getName()) - .subRecommend(mapResult.getName()) + .recommend(matchText.getRegText() + hanlpMapResult.getName()) + .subRecommend(hanlpMapResult.getName()) .schemaElementType(semanticType) .build(); //visibility to filter metrics @@ -332,13 +332,13 @@ public class SearchServiceImpl implements SearchService { visibility = configService.getVisibilityByModelId(modelId); caffeineCache.put(modelId, visibility); } - if (!visibility.getBlackMetricNameList().contains(mapResult.getName()) - && !visibility.getBlackDimNameList().contains(mapResult.getName())) { + if (!visibility.getBlackMetricNameList().contains(hanlpMapResult.getName()) + && !visibility.getBlackDimNameList().contains(hanlpMapResult.getName())) { searchResults.add(searchResult); } } - log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", mapResult, dimensionMetricClassIds, - possibleModels); + log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", hanlpMapResult, + dimensionMetricClassIds, possibleModels); } log.info("searchMetricAndDimension searchResults:{}", searchResults); return existMetric; diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/mapper/match/QueryMatchStrategyTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/mapper/match/HanlpMatchStrategyTest.java similarity index 83% rename from chat/core/src/test/java/com/tencent/supersonic/chat/mapper/match/QueryMatchStrategyTest.java rename to chat/core/src/test/java/com/tencent/supersonic/chat/mapper/match/HanlpMatchStrategyTest.java index c32ee810b..34d913837 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/mapper/match/QueryMatchStrategyTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/mapper/match/HanlpMatchStrategyTest.java @@ -6,7 +6,7 @@ import org.junit.jupiter.api.Test; /** * MatchStrategyImplTest */ -class QueryMatchStrategyTest extends ContextTest { +class HanlpMatchStrategyTest extends ContextTest { @Test void match() { diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/EmbeddingResult.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/EmbeddingResult.java new file mode 100644 index 000000000..66040cdc6 --- /dev/null +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/EmbeddingResult.java @@ -0,0 +1,34 @@ +package com.tencent.supersonic.knowledge.dictionary; + +import com.google.common.base.Objects; +import java.util.Map; +import lombok.Data; +import lombok.ToString; + +@Data +@ToString +public class EmbeddingResult extends MapResult { + + private String id; + + private double distance; + + private Map metadata; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + EmbeddingResult that = (EmbeddingResult) o; + return Objects.equal(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hashCode(id); + } +} \ No newline at end of file diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/HanlpMapResult.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/HanlpMapResult.java new file mode 100644 index 000000000..0bab62d24 --- 
/dev/null +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/HanlpMapResult.java @@ -0,0 +1,44 @@ +package com.tencent.supersonic.knowledge.dictionary; + +import com.google.common.base.Objects; +import java.util.List; +import lombok.Data; +import lombok.ToString; + +@Data +@ToString +public class HanlpMapResult extends MapResult { + + private List natures; + private int offset = 0; + + private double similarity; + + public HanlpMapResult(String name, List natures, String detectWord) { + this.name = name; + this.natures = natures; + this.detectWord = detectWord; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + HanlpMapResult hanlpMapResult = (HanlpMapResult) o; + return Objects.equal(name, hanlpMapResult.name) && Objects.equal(natures, hanlpMapResult.natures); + } + + @Override + public int hashCode() { + return Objects.hashCode(name, natures); + } + + public void setOffset(int offset) { + this.offset = offset; + } + +} \ No newline at end of file diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/MapResult.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/MapResult.java index bcd644ae0..6eb248fc3 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/MapResult.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/MapResult.java @@ -1,8 +1,6 @@ package com.tencent.supersonic.knowledge.dictionary; -import com.google.common.base.Objects; import java.io.Serializable; -import java.util.List; import lombok.Data; import lombok.ToString; @@ -10,43 +8,6 @@ import lombok.ToString; @ToString public class MapResult implements Serializable { - private String name; - private List natures; - private int offset = 0; - - private double similarity; - - private String detectWord; - - public MapResult() { - - } - - public MapResult(String name, List natures, String detectWord) { - this.name = name; - this.natures = natures; - this.detectWord = detectWord; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - MapResult mapResult = (MapResult) o; - return Objects.equal(name, mapResult.name) && Objects.equal(natures, mapResult.natures); - } - - @Override - public int hashCode() { - return Objects.hashCode(name, natures); - } - - public void setOffset(int offset) { - this.offset = offset; - } - + protected String name; + protected String detectWord; } \ No newline at end of file diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SearchService.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SearchService.java index 1f854d12b..7db89361a 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SearchService.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SearchService.java @@ -7,7 +7,7 @@ import com.hankcs.hanlp.dictionary.CoreDictionary; import com.tencent.supersonic.knowledge.dictionary.DictWord; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.knowledge.dictionary.DictionaryAttributeUtil; -import com.tencent.supersonic.knowledge.dictionary.MapResult; +import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult; import java.util.Arrays; import java.util.List; import 
java.util.Map; @@ -38,17 +38,17 @@ public class SearchService { * @param key * @return */ - public static List prefixSearch(String key, int limit, Integer agentId, Set detectModelIds) { + public static List prefixSearch(String key, int limit, Integer agentId, Set detectModelIds) { return prefixSearch(key, limit, agentId, trie, detectModelIds); } - public static List prefixSearch(String key, int limit, Integer agentId, BinTrie> binTrie, - Set detectModelIds) { + public static List prefixSearch(String key, int limit, Integer agentId, + BinTrie> binTrie, Set detectModelIds) { Set>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds); return result.stream().map( entry -> { String name = entry.getKey().replace("#", " "); - return new MapResult(name, entry.getValue(), key); + return new HanlpMapResult(name, entry.getValue(), key); } ).sorted((a, b) -> -(b.getName().length() - a.getName().length())) .limit(SEARCH_SIZE) @@ -60,13 +60,13 @@ public class SearchService { * @param key * @return */ - public static List suffixSearch(String key, int limit, Integer agentId, Set detectModelIds) { + public static List suffixSearch(String key, int limit, Integer agentId, Set detectModelIds) { String reverseDetectSegment = StringUtils.reverse(key); return suffixSearch(reverseDetectSegment, limit, agentId, suffixTrie, detectModelIds); } - public static List suffixSearch(String key, int limit, Integer agentId, BinTrie> binTrie, - Set detectModelIds) { + public static List suffixSearch(String key, int limit, Integer agentId, + BinTrie> binTrie, Set detectModelIds) { Set>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds); return result.stream().map( entry -> { @@ -75,7 +75,7 @@ public class SearchService { .map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), "")) .collect(Collectors.toList()); name = StringUtils.reverse(name); - return new MapResult(name, natures, key); + return new HanlpMapResult(name, natures, key); } ).sorted((a, b) -> -(b.getName().length() - a.getName().length())) .limit(SEARCH_SIZE) diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java index 07051e499..54702adec 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java @@ -9,17 +9,16 @@ import com.hankcs.hanlp.seg.Segment; import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.knowledge.dictionary.DictWord; +import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter; +import com.tencent.supersonic.knowledge.dictionary.MapResult; +import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary; +import com.tencent.supersonic.knowledge.service.SearchService; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; - -import com.tencent.supersonic.knowledge.dictionary.MapResult; -import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter; -import com.tencent.supersonic.knowledge.service.SearchService; -import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -186,11 
diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java
index 07051e499..54702adec 100644
--- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java
+++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java
@@ -9,17 +9,16 @@ import com.hankcs.hanlp.seg.Segment;
 import com.hankcs.hanlp.seg.common.Term;
 import com.tencent.supersonic.common.pojo.enums.DictWordType;
 import com.tencent.supersonic.knowledge.dictionary.DictWord;
+import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter;
+import com.tencent.supersonic.knowledge.dictionary.MapResult;
+import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
+import com.tencent.supersonic.knowledge.service.SearchService;
 import java.io.File;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.stream.Collectors;
-
-import com.tencent.supersonic.knowledge.dictionary.MapResult;
-import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter;
-import com.tencent.supersonic.knowledge.service.SearchService;
-import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.StringUtils;
 import org.springframework.util.CollectionUtils;
@@ -186,11 +185,11 @@ public class HanlpHelper {
         }
     }
 
-    public static void transLetterOriginal(List<MapResult> mapResults) {
+    public static <T extends MapResult> void transLetterOriginal(List<T> mapResults) {
         if (CollectionUtils.isEmpty(mapResults)) {
             return;
         }
-        for (MapResult mapResult : mapResults) {
+        for (T mapResult : mapResults) {
             if (MultiCustomDictionary.isLowerLetter(mapResult.getName())) {
                 if (CustomDictionary.contains(mapResult.getName())) {
                     CoreDictionary.Attribute attribute = CustomDictionary.get(mapResult.getName());
diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java
index 8d2e82195..114470079 100644
--- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java
+++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java
@@ -1,18 +1,28 @@
 package com.tencent.supersonic.common.util.embedding;
 
+import com.tencent.supersonic.common.pojo.enums.DictWordType;
 import lombok.Data;
 import java.util.Map;
+import org.apache.commons.lang3.StringUtils;
 
 @Data
 public class Retrieval {
 
-    private Long id;
+    protected String id;
 
-    private double distance;
+    protected double distance;
 
-    private String query;
+    protected String query;
 
-    private Map metadata;
+    protected Map metadata;
+
+    public static Long getLongId(String id) {
+        if (StringUtils.isBlank(id)) {
+            return null;
+        }
+        String[] split = id.split(DictWordType.NATURE_SPILT);
+        return Long.parseLong(split[0]);
+    }
 }
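The retrieval id is now a string because the embedding listener further down concatenates the element id and its type, and getLongId splits the numeric prefix back out. A tiny sketch of the expected round trip, assuming DictWordType.NATURE_SPILT is the underscore separator used for dictionary natures; the element id and type below are made up:

import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.util.embedding.Retrieval;

public class RetrievalIdSketch {
    public static void main(String[] args) {
        // Composite ids are assumed to look like "<elementId><NATURE_SPILT><type>", e.g. "123_metric".
        String compositeId = "123" + DictWordType.NATURE_SPILT + "metric";
        System.out.println(Retrieval.getLongId(compositeId)); // 123
        System.out.println(Retrieval.getLongId(""));          // null for blank ids
    }
}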
diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java
index 3d763dc9f..defb1fe6a 100644
--- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java
+++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserAddHelperTest.java
@@ -264,7 +264,8 @@ class SqlParserAddHelperTest {
                 + "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql);
 
-        sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11' and count(DISTINCT uv) >1 "
+        sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11'"
+                + " and count(DISTINCT uv) >1 "
                 + "GROUP BY department order by count(DISTINCT uv) desc limit 10";
         replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate);
         replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
@@ -290,7 +291,8 @@ class SqlParserAddHelperTest {
         replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
 
         Assert.assertEquals(
-                "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND count(DISTINCT uv) > 1 "
+                "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = "
+                        + "'2023-09-11' AND count(DISTINCT uv) > 1 "
                         + "AND department = 'HR' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
                 replaceSql);
 
@@ -300,8 +302,10 @@ class SqlParserAddHelperTest {
         replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
 
         Assert.assertEquals(
-                "SELECT department, count(DISTINCT uv) FROM t_1 WHERE (count(DISTINCT uv) > 1 AND department = 'HR') AND "
-                        + "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
+                "SELECT department, count(DISTINCT uv) FROM t_1 WHERE (count(DISTINCT uv) > "
+                        + "1 AND department = 'HR') AND "
+                        + "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY "
+                        + "count(DISTINCT uv) DESC LIMIT 10",
                 replaceSql);
 
         sql = "select department, count(DISTINCT uv) as uv from t_1 where sys_imp_date = '2023-09-11' GROUP BY "
diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories
index dee362afd..abd09f4e7 100644
--- a/launchers/chat/src/main/resources/META-INF/spring.factories
+++ b/launchers/chat/src/main/resources/META-INF/spring.factories
@@ -1,4 +1,5 @@
 com.tencent.supersonic.chat.api.component.SchemaMapper=\
+  com.tencent.supersonic.chat.mapper.EmbeddingMapper, \
   com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
   com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
   com.tencent.supersonic.chat.mapper.QueryFilterMapper, \
diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories
index f53c2908c..e7e93a249 100644
--- a/launchers/standalone/src/main/resources/META-INF/spring.factories
+++ b/launchers/standalone/src/main/resources/META-INF/spring.factories
@@ -1,4 +1,5 @@
 com.tencent.supersonic.chat.api.component.SchemaMapper=\
+  com.tencent.supersonic.chat.mapper.EmbeddingMapper, \
   com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
   com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
   com.tencent.supersonic.chat.mapper.QueryFilterMapper, \
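Both spring.factories files register the new EmbeddingMapper at the head of the SchemaMapper list. A minimal sketch of what a mapper registered this way looks like; the class below is illustrative only, not the actual EmbeddingMapper, and just shows the work(...) hook that BaseMapper wraps with timing and error handling:

package com.tencent.supersonic.chat.mapper;

import com.tencent.supersonic.chat.api.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;

// Illustrative SchemaMapper: discovered via the spring.factories entries above and
// invoked once per query; a real mapper would add SchemaElementMatch entries to the map info.
@Slf4j
public class DemoMapper extends BaseMapper {

    @Override
    public void work(QueryContext queryContext) {
        // Hypothetical no-op mapper that only logs the incoming query text.
        log.info("queryText:{}", queryContext.getRequest().getQueryText());
    }
}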
diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java
index 07403b7bd..6491cd29a 100644
--- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java
+++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java
@@ -2,18 +2,18 @@ package com.tencent.supersonic.semantic.model.domain.listener;
 
 import com.alibaba.fastjson.JSONObject;
 import com.tencent.supersonic.common.pojo.DataEvent;
+import com.tencent.supersonic.common.pojo.enums.DictWordType;
 import com.tencent.supersonic.common.pojo.enums.EventType;
-import com.tencent.supersonic.common.pojo.enums.TypeEnums;
 import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
 import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationListener;
 import org.springframework.stereotype.Component;
 import org.springframework.util.CollectionUtils;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
 
 @Component
 @Slf4j
@@ -30,15 +30,17 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
             return;
         }
         List<EmbeddingQuery> embeddingQueries = event.getDataItems()
-                .stream().filter(dataItem -> dataItem.getType().equals(TypeEnums.METRIC)).map(dataItem -> {
-                    EmbeddingQuery embeddingQuery = new EmbeddingQuery();
-                    embeddingQuery.setQueryId(dataItem.getId().toString());
-                    embeddingQuery.setQuery(dataItem.getName());
-                    Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
-                    embeddingQuery.setMetadata(meta);
-                    embeddingQuery.setQueryEmbedding(null);
-                    return embeddingQuery;
-                }).collect(Collectors.toList());
+                .stream()
+                .map(dataItem -> {
+                    EmbeddingQuery embeddingQuery = new EmbeddingQuery();
+                    embeddingQuery.setQueryId(
+                            dataItem.getId().toString() + DictWordType.NATURE_SPILT + dataItem.getType().getName());
+                    embeddingQuery.setQuery(dataItem.getName());
+                    Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
+                    embeddingQuery.setMetadata(meta);
+                    embeddingQuery.setQueryEmbedding(null);
+                    return embeddingQuery;
+                }).collect(Collectors.toList());
         if (CollectionUtils.isEmpty(embeddingQueries)) {
             return;
         }
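The listener now indexes every DataItem, not just metrics, and stores the fastjson-serialized item as the EmbeddingQuery metadata. A small self-contained sketch of that serialization round trip with a made-up two-field bean; the real DataItem carries more attributes, and this only shows the Map shape that ends up stored alongside the embedding:

import com.alibaba.fastjson.JSONObject;
import java.util.Map;

public class MetadataSketch {

    // Stand-in bean for a DataItem; field names and values are invented.
    public static class DemoItem {
        private Long id = 42L;
        private String name = "pv";

        public Long getId() {
            return id;
        }

        public String getName() {
            return name;
        }
    }

    public static void main(String[] args) {
        // Same pattern the listener uses: bean -> JSON string -> Map metadata.
        Map meta = JSONObject.parseObject(JSONObject.toJSONString(new DemoItem()), Map.class);
        System.out.println(meta); // e.g. {id=42, name=pv}, key order may vary
    }
}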