(improvement)(chat) Optimize and fix the NatureHelper code. (#1510)

Co-authored-by: lexluo <lexluo@tencent.com>
This commit is contained in:
lexluo09
2024-08-03 23:52:52 +08:00
committed by GitHub
parent ac6b28ebb7
commit a9232fa1c7
2 changed files with 98 additions and 116 deletions

View File

@@ -1,16 +1,11 @@
package com.tencent.supersonic.headless.chat.knowledge.helper; package com.tencent.supersonic.headless.chat.knowledge.helper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.corpus.tag.Nature; import com.hankcs.hanlp.corpus.tag.Nature;
import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat; import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat;
import lombok.extern.slf4j.Slf4j; import java.util.Collections;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@@ -18,6 +13,9 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
/** /**
* nature parse helper * nature parse helper
@@ -57,54 +55,35 @@ public class NatureHelper {
} }
private static boolean isDataSetOrEntity(S2Term term, Integer model) { private static boolean isDataSetOrEntity(S2Term term, Integer model) {
return (DictWordType.NATURE_SPILT + model).equals(term.nature.toString()) || term.nature.toString() String natureStr = term.nature.toString();
.endsWith(DictWordType.ENTITY.getType()); return (DictWordType.NATURE_SPILT + model).equals(natureStr) || natureStr.endsWith(
DictWordType.ENTITY.getType());
} }
public static Integer getDataSetByNature(Nature nature) { public static Integer getDataSetByNature(Nature nature) {
if (nature.startsWith(DictWordType.NATURE_SPILT)) { if (!nature.startsWith(DictWordType.NATURE_SPILT)) {
String[] dimensionValues = nature.toString().split(DictWordType.NATURE_SPILT); return 0;
if (StringUtils.isNumeric(dimensionValues[1])) {
return Integer.valueOf(dimensionValues[1]);
}
} }
return 0; String[] dimensionValues = nature.toString().split(DictWordType.NATURE_SPILT);
return StringUtils.isNumeric(dimensionValues[1]) ? Integer.valueOf(dimensionValues[1]) : 0;
} }
public static Long getDataSetId(String nature) { public static Long getDataSetId(String nature) {
try { return parseIdFromNature(nature, 1);
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {
return null;
}
return Long.valueOf(split[1]);
} catch (NumberFormatException e) {
log.error("", e);
}
return null;
} }
private static Long getModelId(String nature) { private static Long getModelId(String nature) {
try { return parseIdFromNature(nature, 1);
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {
return null;
}
return Long.valueOf(split[1]);
} catch (NumberFormatException e) {
log.error("", e);
}
return null;
} }
private static Nature changeModel2DataSet(String nature, Long dataSetId) { private static String changeModel2DataSet(String nature, Long dataSetId) {
try { try {
String[] split = nature.split(DictWordType.NATURE_SPILT); String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) { if (split.length <= 1) {
return null; return null;
} }
split[1] = String.valueOf(dataSetId); split[1] = String.valueOf(dataSetId);
return Nature.create(StringUtils.join(split, DictWordType.NATURE_SPILT)); return String.join(DictWordType.NATURE_SPILT, split);
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
log.error("", e); log.error("", e);
} }
@@ -112,47 +91,28 @@ public class NatureHelper {
} }
public static List<String> changeModel2DataSet(String nature, Map<Long, List<Long>> modelIdToDataSetIds) { public static List<String> changeModel2DataSet(String nature, Map<Long, List<Long>> modelIdToDataSetIds) {
//term prefix id is dataSetId, no need to transform
if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) { if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) {
return Lists.newArrayList(nature); return Collections.singletonList(nature);
} }
Long modelId = getModelId(nature); Long modelId = getModelId(nature);
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId); List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
if (CollectionUtils.isEmpty(dataSetIds)) { if (CollectionUtils.isEmpty(dataSetIds)) {
return Lists.newArrayList(); return Collections.emptyList();
} }
return dataSetIds.stream().map(dataSetId -> String.valueOf(changeModel2DataSet(nature, dataSetId))) return dataSetIds.stream()
.map(dataSetId -> changeModel2DataSet(nature, dataSetId))
.filter(Objects::nonNull)
.map(String::valueOf)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
public static boolean isDimensionValueDataSetId(String nature) { public static boolean isDimensionValueDataSetId(String nature) {
if (StringUtils.isEmpty(nature)) { return isNatureValid(nature) && !isNatureType(nature, DictWordType.METRIC, DictWordType.DIMENSION,
return false; DictWordType.TERM) && StringUtils.isNumeric(nature.split(DictWordType.NATURE_SPILT)[1]);
}
if (!nature.startsWith(DictWordType.NATURE_SPILT)) {
return false;
}
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {
return false;
}
return !nature.endsWith(DictWordType.METRIC.getType()) && !nature.endsWith(
DictWordType.DIMENSION.getType()) && !nature.endsWith(DictWordType.TERM.getType())
&& StringUtils.isNumeric(split[1]);
} }
public static boolean isTermNature(String nature) { public static boolean isTermNature(String nature) {
if (StringUtils.isEmpty(nature)) { return isNatureValid(nature) && nature.endsWith(DictWordType.TERM.getType());
return false;
}
if (!nature.startsWith(DictWordType.NATURE_SPILT)) {
return false;
}
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {
return false;
}
return nature.endsWith(DictWordType.TERM.getType());
} }
public static DataSetInfoStat getDataSetStat(List<S2Term> terms) { public static DataSetInfoStat getDataSetStat(List<S2Term> terms) {
@@ -182,72 +142,67 @@ public class NatureHelper {
.endsWith(DictWordType.METRIC.getType())).count(); .endsWith(DictWordType.METRIC.getType())).count();
} }
/**
* Get the number of types of class parts of speech
* modelId -> (nature , natureCount)
*
* @param terms
* @return
*/
public static Map<Long, Map<DictWordType, Integer>> getDataSetToNatureStat(List<S2Term> terms) { public static Map<Long, Map<DictWordType, Integer>> getDataSetToNatureStat(List<S2Term> terms) {
Map<Long, Map<DictWordType, Integer>> modelToNature = new HashMap<>(); Map<Long, Map<DictWordType, Integer>> modelToNature = new HashMap<>();
terms.stream().filter( terms.stream()
term -> term.nature.startsWith(DictWordType.NATURE_SPILT) .filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT))
).forEach(term -> { .forEach(term -> {
DictWordType dictWordType = DictWordType.getNatureType(String.valueOf(term.nature)); DictWordType dictWordType = DictWordType.getNatureType(term.nature.toString());
Long model = getDataSetId(String.valueOf(term.nature)); Long model = getDataSetId(term.nature.toString());
Map<DictWordType, Integer> natureTypeMap = new HashMap<>(); modelToNature.computeIfAbsent(model, k -> new HashMap<>())
natureTypeMap.put(dictWordType, 1); .merge(dictWordType, 1, Integer::sum);
});
Map<DictWordType, Integer> original = modelToNature.get(model);
if (Objects.isNull(original)) {
modelToNature.put(model, natureTypeMap);
} else {
Integer count = original.get(dictWordType);
if (Objects.isNull(count)) {
count = 1;
} else {
count = count + 1;
}
original.put(dictWordType, count);
}
});
return modelToNature; return modelToNature;
} }
public static List<Long> selectPossibleDataSets(List<S2Term> terms) { public static List<Long> selectPossibleDataSets(List<S2Term> terms) {
Map<Long, Map<DictWordType, Integer>> modelToNatureStat = getDataSetToNatureStat(terms); Map<Long, Map<DictWordType, Integer>> modelToNatureStat = getDataSetToNatureStat(terms);
Integer maxDataSetTypeSize = modelToNatureStat.entrySet().stream() return modelToNatureStat.entrySet().stream()
.max(Comparator.comparingInt(o -> o.getValue().size())).map(entry -> entry.getValue().size()) .max(Comparator.comparingInt(entry -> entry.getValue().size()))
.orElse(null); .map(entry -> modelToNatureStat.entrySet().stream()
if (Objects.isNull(maxDataSetTypeSize) || maxDataSetTypeSize == 0) { .filter(e -> e.getValue().size() == entry.getValue().size())
return new ArrayList<>(); .map(Map.Entry::getKey)
} .collect(Collectors.toList()))
return modelToNatureStat.entrySet().stream().filter(entry -> entry.getValue().size() == maxDataSetTypeSize) .orElse(Collections.emptyList());
.map(entry -> entry.getKey()).collect(Collectors.toList());
} }
public static Long getElementID(String nature) { public static Long getElementID(String nature) {
String[] split = nature.split(DictWordType.NATURE_SPILT); return parseIdFromNature(nature, 2);
if (split.length >= 3) {
return Long.valueOf(split[2]);
}
return 0L;
} }
public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) { public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Long> detectModelIds = modelIdToDataSetIds.keySet(); if (CollectionUtils.isEmpty(detectDataSetIds)) {
if (!CollectionUtils.isEmpty(detectDataSetIds)) { return modelIdToDataSetIds.keySet();
detectModelIds = modelIdToDataSetIds.entrySet().stream().filter(entry -> {
List<Long> dataSetIds = entry.getValue().stream().filter(detectDataSetIds::contains)
.collect(Collectors.toList());
if (!CollectionUtils.isEmpty(dataSetIds)) {
return true;
}
return false;
}).map(entry -> entry.getKey()).collect(Collectors.toSet());
} }
return detectModelIds; return modelIdToDataSetIds.entrySet().stream()
.filter(entry -> !Collections.disjoint(entry.getValue(), detectDataSetIds))
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
}
/**
 * Parses the segment at {@code index} of a separator-delimited nature string as a long.
 *
 * @param nature nature string to split; may be {@code null} or empty
 * @param index  zero-based position of the segment holding the id
 * @return the parsed id, or {@code null} when the nature is null or empty, the index is
 *         negative or beyond the last segment, or the segment is not a valid long
 */
public static Long parseIdFromNature(String nature, int index) {
    // Guard the inputs so this public helper never fails with a NullPointerException
    // or an ArrayIndexOutOfBoundsException.
    if (StringUtils.isEmpty(nature) || index < 0) {
        return null;
    }
    String[] split = nature.split(DictWordType.NATURE_SPILT);
    if (split.length <= index) {
        return null;
    }
    try {
        return Long.valueOf(split[index]);
    } catch (NumberFormatException e) {
        log.error("Error parsing long from nature: {}", nature, e);
        return null;
    }
}
/**
 * Checks the basic shape of a nature string.
 *
 * @param nature nature string to test; may be {@code null}
 * @return {@code true} when the string is non-empty and starts with the nature separator
 */
private static boolean isNatureValid(String nature) {
    if (StringUtils.isEmpty(nature)) {
        return false;
    }
    return nature.startsWith(DictWordType.NATURE_SPILT);
}
private static boolean isNatureType(String nature, DictWordType... types) {
for (DictWordType type : types) {
if (nature.endsWith(type.getType())) {
return true;
}
}
return false;
} }
} }

View File

@@ -0,0 +1,27 @@
package com.tencent.supersonic.headless;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.ItemValueReq;
import com.tencent.supersonic.headless.api.pojo.response.ItemValueResp;
import com.tencent.supersonic.headless.server.service.TagQueryService;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
import org.springframework.beans.factory.annotation.Autowired;
/**
 * Test for tag value queries served by {@link TagQueryService}.
 */
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
public class TagTest extends BaseTest {

    @Autowired
    private TagQueryService tagQueryService;

    /**
     * Queries the tag values of the item with id 1 as the fake user and expects a
     * non-null response.
     */
    @Test
    public void testQueryTagValue() throws Exception {
        ItemValueReq request = new ItemValueReq();
        request.setId(1L);
        User caller = User.getFakeUser();
        ItemValueResp response = tagQueryService.queryTagValue(request, caller);
        Assertions.assertNotNull(response);
    }
}