From 2ed776221086823017a2e812a66270d691b18e12 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Thu, 6 Jun 2024 10:10:40 +0800 Subject: [PATCH] (improvement)(Headless) Filtering based on dataSetIds during Mapper detection Compatible with term (#1096) Co-authored-by: jolunoluo --- .../com/hankcs/hanlp/LoadRemoveService.java | 15 +----- .../collection/trie/bintrie/BaseNode.java | 21 ++++---- .../core/chat/knowledge/SearchService.java | 50 +++++++++++-------- .../tencent/supersonic/demo/S2VisitsDemo.java | 1 - 4 files changed, 40 insertions(+), 47 deletions(-) diff --git a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java index 745555398..b31e87d28 100644 --- a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java +++ b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java @@ -10,7 +10,6 @@ import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Set; @Data @Slf4j @@ -19,23 +18,11 @@ public class LoadRemoveService { @Value("${mapper.remove.nature.prefix:}") private String mapperRemoveNaturePrefix; - public List removeNatures(List value, Set detectModelIds) { + public List removeNatures(List value) { if (CollectionUtils.isEmpty(value)) { return value; } List resultList = new ArrayList<>(value); - if (!CollectionUtils.isEmpty(detectModelIds)) { - resultList.removeIf(nature -> { - if (Objects.isNull(nature)) { - return false; - } - Long modelId = getDataSetId(nature); - if (Objects.nonNull(modelId)) { - return !detectModelIds.contains(modelId); - } - return false; - }); - } if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) { resultList.removeIf(nature -> { if (Objects.isNull(nature)) { diff --git a/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java b/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java index 0f50c76ba..53d61d033 100644 --- a/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java +++ b/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java @@ -2,6 +2,9 @@ package com.hankcs.hanlp.collection.trie.bintrie; import com.hankcs.hanlp.LoadRemoveService; import com.hankcs.hanlp.corpus.io.ByteArray; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.DataOutputStream; import java.io.IOException; import java.io.ObjectInput; @@ -14,8 +17,6 @@ import java.util.Map; import java.util.Objects; import java.util.Queue; import java.util.Set; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public abstract class BaseNode implements Comparable { @@ -286,12 +287,12 @@ public abstract class BaseNode implements Comparable { + '}'; } - public void walkNode(Set> entrySet, Set detectModelIds) { + public void walkNode(Set> entrySet) { if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) { - logger.debug("detectModelIds:{},before:{}", detectModelIds, value.toString()); - List natures = new LoadRemoveService().removeNatures((List) value, detectModelIds); + logger.debug("walkNode before:{}", value.toString()); + List natures = new LoadRemoveService().removeNatures((List) value); String name = this.prefix != null ? this.prefix + c : "" + c; - logger.debug("name:{},after:{},natures:{}", name, (List) value, natures); + logger.debug("walkNode name:{},after:{},natures:{}", name, (List) value, natures); entrySet.add(new TrieEntry(name, (V) natures)); } } @@ -300,21 +301,17 @@ public abstract class BaseNode implements Comparable { * walk limit * @param sb * @param entrySet - * @param limit */ - public void walkLimit(StringBuilder sb, Set> entrySet, int limit, Set detectModelIds) { + public void walkLimit(StringBuilder sb, Set> entrySet) { Queue queue = new ArrayDeque<>(); this.prefix = sb.toString(); queue.add(this); while (!queue.isEmpty()) { - if (entrySet.size() >= limit) { - break; - } BaseNode root = queue.poll(); if (root == null) { continue; } - root.walkNode(entrySet, detectModelIds); + root.walkNode(entrySet); if (root.child == null) { continue; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java index 4174b82f1..6ef88e8e1 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java @@ -48,22 +48,16 @@ public class SearchService { public static List prefixSearch(String key, int limit, BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { - Set>> result = prefixSearchLimit(key, limit, binTrie, - modelIdToDataSetIds, detectDataSetIds); + Set>> result = search(key, binTrie); List hanlpMapResults = result.stream().map( entry -> { String name = entry.getKey().replace("#", " "); return new HanlpMapResult(name, entry.getValue(), key); } ).sorted((a, b) -> -(b.getName().length() - a.getName().length())) - .limit(SEARCH_SIZE) .collect(Collectors.toList()); - for (HanlpMapResult hanlpMapResult : hanlpMapResults) { - List natures = hanlpMapResult.getNatures().stream() - .map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds)) - .flatMap(Collection::stream).collect(Collectors.toList()); - hanlpMapResult.setNatures(natures); - } + hanlpMapResults = transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, + detectDataSetIds, limit); return hanlpMapResults; } @@ -80,11 +74,8 @@ public class SearchService { public static List suffixSearch(String key, int limit, BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { - - Set>> result = prefixSearchLimit(key, limit, binTrie, modelIdToDataSetIds, - detectDataSetIds); - - return result.stream().map( + Set>> result = search(key, binTrie); + List hanlpMapResults = result.stream().map( entry -> { String name = entry.getKey().replace("#", " "); List natures = entry.getValue().stream() @@ -94,15 +85,34 @@ public class SearchService { return new HanlpMapResult(name, natures, key); } ).sorted((a, b) -> -(b.getName().length() - a.getName().length())) - .limit(SEARCH_SIZE) .collect(Collectors.toList()); + return transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit); } - private static Set>> prefixSearchLimit(String key, int limit, - BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { - - Set detectModelIds = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds); + private static List transformAndFilterByDataSet(List hanlpMapResults, + Map> modelIdToDataSetIds, + Set detectDataSetIds, int limit) { + return hanlpMapResults.stream().peek(hanlpMapResult -> { + List natures = hanlpMapResult.getNatures().stream() + .map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds)) + .flatMap(Collection::stream) + .filter(nature -> { + if (CollectionUtils.isEmpty(detectDataSetIds)) { + return true; + } + Long dataSetId = NatureHelper.getDataSetId(nature); + if (dataSetId != null) { + return detectDataSetIds.contains(dataSetId); + } + return false; + }).collect(Collectors.toList()); + hanlpMapResult.setNatures(natures); + }).filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures())) + .limit(limit).collect(Collectors.toList()); + } + private static Set>> search(String key, + BinTrie> binTrie) { key = key.toLowerCase(); Set>> entrySet = new TreeSet>>(); @@ -122,7 +132,7 @@ public class SearchService { if (branch == null) { return entrySet; } - branch.walkLimit(sb, entrySet, limit, detectModelIds); + branch.walkLimit(sb, entrySet); return entrySet; } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index e489628c1..15139196d 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -16,7 +16,6 @@ import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.build.WebBase; import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.ModelRela; -import com.tencent.supersonic.common.pojo.SystemConfig; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;