(improvement)(Headless) Avoid term and dimension value conflicts (#1026)

This commit is contained in:
LXW
2024-05-22 18:38:30 +08:00
committed by GitHub
parent 418abef982
commit 987154c4a3
4 changed files with 44 additions and 4 deletions

View File

@@ -141,6 +141,20 @@ public class NatureHelper {
&& StringUtils.isNumeric(split[1]); && StringUtils.isNumeric(split[1]);
} }
public static boolean isTermNature(String nature) {
if (StringUtils.isEmpty(nature)) {
return false;
}
if (!nature.startsWith(DictWordType.NATURE_SPILT)) {
return false;
}
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {
return false;
}
return nature.endsWith(DictWordType.TERM.getType());
}
public static DataSetInfoStat getDataSetStat(List<S2Term> terms) { public static DataSetInfoStat getDataSetStat(List<S2Term> terms) {
return DataSetInfoStat.builder() return DataSetInfoStat.builder()
.dataSetCount(getDataSetCount(terms)) .dataSetCount(getDataSetCount(terms))

View File

@@ -10,11 +10,13 @@ import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -43,7 +45,7 @@ public abstract class BaseMapper implements SchemaMapper {
} }
private void filter(QueryContext queryContext) { private void filter(QueryContext queryContext) {
filterByDataSetId(queryContext);
switch (queryContext.getQueryDataType()) { switch (queryContext.getQueryDataType()) {
case TAG: case TAG:
filterByQueryDataType(queryContext, element -> !(element.getIsTag() > 0)); filterByQueryDataType(queryContext, element -> !(element.getIsTag() > 0));
@@ -62,7 +64,19 @@ public abstract class BaseMapper implements SchemaMapper {
default: default:
break; break;
} }
}
private static void filterByDataSetId(QueryContext queryContext) {
Set<Long> dataSetIds = queryContext.getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
return;
}
Set<Long> dataSetIdInMapInfo = queryContext.getMapInfo().getDataSetElementMatches().keySet();
for (Long dataSetId : dataSetIdInMapInfo) {
if (!dataSetIds.contains(dataSetId)) {
queryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId);
}
}
} }
private static void filterByQueryDataType(QueryContext queryContext, Predicate<SchemaElement> needRemovePredicate) { private static void filterByQueryDataType(QueryContext queryContext, Predicate<SchemaElement> needRemovePredicate) {

View File

@@ -2,10 +2,10 @@ package com.tencent.supersonic.headless.core.chat.mapper;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@@ -108,9 +108,12 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize(); Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize) List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
.collect(Collectors.toList()); .collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(dimensionMetrics)) { if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
oneRoundResults = dimensionMetrics; oneRoundResults = dimensionMetrics;
List<HanlpMapResult> termOneRoundResults = hanlpMapResults.stream()
.filter(hanlpMapResult -> mapperHelper.existTerms(hanlpMapResult.getNatures()))
.collect(Collectors.toList());
oneRoundResults.addAll(termOneRoundResults);
} }
// step6. select mapResul in one round // step6. select mapResul in one round
selectResultInOneRound(existResults, oneRoundResults); selectResultInOneRound(existResults, oneRoundResults);

View File

@@ -59,6 +59,15 @@ public class MapperHelper {
return false; return false;
} }
public boolean existTerms(List<String> natures) {
for (String nature : natures) {
if (NatureHelper.isTermNature(nature)) {
return true;
}
}
return false;
}
/*** /***
* get similarity * get similarity
* @param detectSegment * @param detectSegment