From a9232fa1c724c2e86a70fef7f91689a6f8884692 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sat, 3 Aug 2024 23:52:52 +0800 Subject: [PATCH] (improvement)(chat) Optimize and fix the NatureHelper code. (#1510) Co-authored-by: lexluo --- .../chat/knowledge/helper/NatureHelper.java | 187 +++++++----------- .../tencent/supersonic/headless/TagTest.java | 27 +++ 2 files changed, 98 insertions(+), 116 deletions(-) create mode 100644 launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java index bfedb3229..9499c81a4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java @@ -1,16 +1,11 @@ package com.tencent.supersonic.headless.chat.knowledge.helper; -import com.google.common.collect.Lists; import com.hankcs.hanlp.corpus.tag.Nature; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.util.CollectionUtils; - -import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -18,6 +13,9 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; /** * nature parse helper @@ -57,54 +55,35 @@ public class NatureHelper { } private static boolean isDataSetOrEntity(S2Term term, Integer model) { - return (DictWordType.NATURE_SPILT + model).equals(term.nature.toString()) || term.nature.toString() - .endsWith(DictWordType.ENTITY.getType()); + String natureStr = term.nature.toString(); + return (DictWordType.NATURE_SPILT + model).equals(natureStr) || natureStr.endsWith( + DictWordType.ENTITY.getType()); } public static Integer getDataSetByNature(Nature nature) { - if (nature.startsWith(DictWordType.NATURE_SPILT)) { - String[] dimensionValues = nature.toString().split(DictWordType.NATURE_SPILT); - if (StringUtils.isNumeric(dimensionValues[1])) { - return Integer.valueOf(dimensionValues[1]); - } + if (!nature.startsWith(DictWordType.NATURE_SPILT)) { + return 0; } - return 0; + String[] dimensionValues = nature.toString().split(DictWordType.NATURE_SPILT); + return StringUtils.isNumeric(dimensionValues[1]) ? Integer.valueOf(dimensionValues[1]) : 0; } public static Long getDataSetId(String nature) { - try { - String[] split = nature.split(DictWordType.NATURE_SPILT); - if (split.length <= 1) { - return null; - } - return Long.valueOf(split[1]); - } catch (NumberFormatException e) { - log.error("", e); - } - return null; + return parseIdFromNature(nature, 1); } private static Long getModelId(String nature) { - try { - String[] split = nature.split(DictWordType.NATURE_SPILT); - if (split.length <= 1) { - return null; - } - return Long.valueOf(split[1]); - } catch (NumberFormatException e) { - log.error("", e); - } - return null; + return parseIdFromNature(nature, 1); } - private static Nature changeModel2DataSet(String nature, Long dataSetId) { + private static String changeModel2DataSet(String nature, Long dataSetId) { try { String[] split = nature.split(DictWordType.NATURE_SPILT); if (split.length <= 1) { return null; } split[1] = String.valueOf(dataSetId); - return Nature.create(StringUtils.join(split, DictWordType.NATURE_SPILT)); + return String.join(DictWordType.NATURE_SPILT, split); } catch (NumberFormatException e) { log.error("", e); } @@ -112,47 +91,28 @@ public class NatureHelper { } public static List changeModel2DataSet(String nature, Map> modelIdToDataSetIds) { - //term prefix id is dataSetId, no need to transform if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) { - return Lists.newArrayList(nature); + return Collections.singletonList(nature); } Long modelId = getModelId(nature); List dataSetIds = modelIdToDataSetIds.get(modelId); if (CollectionUtils.isEmpty(dataSetIds)) { - return Lists.newArrayList(); + return Collections.emptyList(); } - return dataSetIds.stream().map(dataSetId -> String.valueOf(changeModel2DataSet(nature, dataSetId))) + return dataSetIds.stream() + .map(dataSetId -> changeModel2DataSet(nature, dataSetId)) + .filter(Objects::nonNull) + .map(String::valueOf) .collect(Collectors.toList()); } public static boolean isDimensionValueDataSetId(String nature) { - if (StringUtils.isEmpty(nature)) { - return false; - } - if (!nature.startsWith(DictWordType.NATURE_SPILT)) { - return false; - } - String[] split = nature.split(DictWordType.NATURE_SPILT); - if (split.length <= 1) { - return false; - } - return !nature.endsWith(DictWordType.METRIC.getType()) && !nature.endsWith( - DictWordType.DIMENSION.getType()) && !nature.endsWith(DictWordType.TERM.getType()) - && StringUtils.isNumeric(split[1]); + return isNatureValid(nature) && !isNatureType(nature, DictWordType.METRIC, DictWordType.DIMENSION, + DictWordType.TERM) && StringUtils.isNumeric(nature.split(DictWordType.NATURE_SPILT)[1]); } public static boolean isTermNature(String nature) { - if (StringUtils.isEmpty(nature)) { - return false; - } - if (!nature.startsWith(DictWordType.NATURE_SPILT)) { - return false; - } - String[] split = nature.split(DictWordType.NATURE_SPILT); - if (split.length <= 1) { - return false; - } - return nature.endsWith(DictWordType.TERM.getType()); + return isNatureValid(nature) && nature.endsWith(DictWordType.TERM.getType()); } public static DataSetInfoStat getDataSetStat(List terms) { @@ -182,72 +142,67 @@ public class NatureHelper { .endsWith(DictWordType.METRIC.getType())).count(); } - /** - * Get the number of types of class parts of speech - * modelId -> (nature , natureCount) - * - * @param terms - * @return - */ public static Map> getDataSetToNatureStat(List terms) { Map> modelToNature = new HashMap<>(); - terms.stream().filter( - term -> term.nature.startsWith(DictWordType.NATURE_SPILT) - ).forEach(term -> { - DictWordType dictWordType = DictWordType.getNatureType(String.valueOf(term.nature)); - Long model = getDataSetId(String.valueOf(term.nature)); + terms.stream() + .filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT)) + .forEach(term -> { + DictWordType dictWordType = DictWordType.getNatureType(term.nature.toString()); + Long model = getDataSetId(term.nature.toString()); - Map natureTypeMap = new HashMap<>(); - natureTypeMap.put(dictWordType, 1); - - Map original = modelToNature.get(model); - if (Objects.isNull(original)) { - modelToNature.put(model, natureTypeMap); - } else { - Integer count = original.get(dictWordType); - if (Objects.isNull(count)) { - count = 1; - } else { - count = count + 1; - } - original.put(dictWordType, count); - } - }); + modelToNature.computeIfAbsent(model, k -> new HashMap<>()) + .merge(dictWordType, 1, Integer::sum); + }); return modelToNature; } public static List selectPossibleDataSets(List terms) { Map> modelToNatureStat = getDataSetToNatureStat(terms); - Integer maxDataSetTypeSize = modelToNatureStat.entrySet().stream() - .max(Comparator.comparingInt(o -> o.getValue().size())).map(entry -> entry.getValue().size()) - .orElse(null); - if (Objects.isNull(maxDataSetTypeSize) || maxDataSetTypeSize == 0) { - return new ArrayList<>(); - } - return modelToNatureStat.entrySet().stream().filter(entry -> entry.getValue().size() == maxDataSetTypeSize) - .map(entry -> entry.getKey()).collect(Collectors.toList()); + return modelToNatureStat.entrySet().stream() + .max(Comparator.comparingInt(entry -> entry.getValue().size())) + .map(entry -> modelToNatureStat.entrySet().stream() + .filter(e -> e.getValue().size() == entry.getValue().size()) + .map(Map.Entry::getKey) + .collect(Collectors.toList())) + .orElse(Collections.emptyList()); } public static Long getElementID(String nature) { - String[] split = nature.split(DictWordType.NATURE_SPILT); - if (split.length >= 3) { - return Long.valueOf(split[2]); - } - return 0L; + return parseIdFromNature(nature, 2); } public static Set getModelIds(Map> modelIdToDataSetIds, Set detectDataSetIds) { - Set detectModelIds = modelIdToDataSetIds.keySet(); - if (!CollectionUtils.isEmpty(detectDataSetIds)) { - detectModelIds = modelIdToDataSetIds.entrySet().stream().filter(entry -> { - List dataSetIds = entry.getValue().stream().filter(detectDataSetIds::contains) - .collect(Collectors.toList()); - if (!CollectionUtils.isEmpty(dataSetIds)) { - return true; - } - return false; - }).map(entry -> entry.getKey()).collect(Collectors.toSet()); + if (CollectionUtils.isEmpty(detectDataSetIds)) { + return modelIdToDataSetIds.keySet(); } - return detectModelIds; + return modelIdToDataSetIds.entrySet().stream() + .filter(entry -> !Collections.disjoint(entry.getValue(), detectDataSetIds)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + + public static Long parseIdFromNature(String nature, int index) { + try { + String[] split = nature.split(DictWordType.NATURE_SPILT); + if (split.length > index) { + return Long.valueOf(split[index]); + } + } catch (NumberFormatException e) { + log.error("Error parsing long from nature: {}", nature, e); + } + return null; + } + + private static boolean isNatureValid(String nature) { + return StringUtils.isNotEmpty(nature) && nature.startsWith(DictWordType.NATURE_SPILT); + } + + private static boolean isNatureType(String nature, DictWordType... types) { + for (DictWordType type : types) { + if (nature.endsWith(type.getType())) { + return true; + } + } + return false; } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java new file mode 100644 index 000000000..2cf885617 --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java @@ -0,0 +1,27 @@ +package com.tencent.supersonic.headless; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.request.ItemValueReq; +import com.tencent.supersonic.headless.api.pojo.response.ItemValueResp; +import com.tencent.supersonic.headless.server.service.TagQueryService; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; +import org.springframework.beans.factory.annotation.Autowired; + +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +public class TagTest extends BaseTest { + + @Autowired + private TagQueryService tagQueryService; + + @Test + public void testQueryTagValue() throws Exception { + ItemValueReq itemValueReq = new ItemValueReq(); + itemValueReq.setId(1L); + ItemValueResp itemValueResp = tagQueryService.queryTagValue(itemValueReq, User.getFakeUser()); + Assertions.assertNotNull(itemValueResp); + } + +} \ No newline at end of file