From 879696e4931593ef2dd75c85210725177e00923b Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Tue, 10 Sep 2024 22:58:34 +0800 Subject: [PATCH] [improvement](chat) Optimize the Schema Mapping rules (#1649) --- .../embedding/EmbeddingRecallRecognizer.java | 10 +- .../execute/MetricRecommendProcessor.java | 2 +- .../service/impl/EmbeddingServiceImpl.java | 4 +- .../store/embedding/Retrieval.java | 6 +- .../headless/api/pojo/SchemaElementMatch.java | 2 +- .../chat/knowledge/EmbeddingResult.java | 3 - .../chat/knowledge/HanlpMapResult.java | 7 +- .../headless/chat/knowledge/MapResult.java | 7 +- .../chat/knowledge/SearchService.java | 9 +- .../chat/knowledge/helper/HanlpHelper.java | 7 +- .../headless/chat/mapper/BaseMapper.java | 31 +--- .../chat/mapper/BaseMatchStrategy.java | 2 +- .../chat/mapper/DatabaseMatchStrategy.java | 7 +- .../headless/chat/mapper/EmbeddingMapper.java | 2 +- .../chat/mapper/EmbeddingMatchStrategy.java | 23 ++- .../chat/mapper/HanlpDictMatchStrategy.java | 19 +-- .../headless/chat/mapper/KeywordMapper.java | 4 +- .../headless/chat/mapper/MapFilter.java | 136 +++++++++++++++--- .../headless/chat/mapper/MapperHelper.java | 16 --- .../chat/utils/EditDistanceUtils.java | 27 ++++ 20 files changed, 202 insertions(+), 122 deletions(-) create mode 100644 headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java index a32318fdf..e0b1e3f1e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java @@ -51,13 +51,13 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer { continue; } plugin.setParseMode(ParseMode.EMBEDDING_RECALL); - double distance = embeddingRetrieval.getDistance(); - double score = parseContext.getQueryText().length() * (1 - distance); + double similarity = embeddingRetrieval.getSimilarity(); + double score = parseContext.getQueryText().length() * similarity; return PluginRecallResult.builder() .plugin(plugin) .dataSetIds(dataSetList) .score(score) - .distance(distance) + .distance(similarity) .build(); } } @@ -73,7 +73,9 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer { if (!CollectionUtils.isEmpty(embeddingRetrievals)) { embeddingRetrievals = embeddingRetrievals.stream() - .sorted(Comparator.comparingDouble(o -> Math.abs(o.getDistance()))) + .sorted( + Comparator.comparingDouble( + o -> Math.abs(o.getSimilarity()))) .collect(Collectors.toList()); embeddingResp.setRetrieval(embeddingRetrievals); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java index e40359dac..a0e1e6480 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java @@ -65,7 +65,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor { List retrievals = retrieveQueryResults.stream() .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()) - .sorted(Comparator.comparingDouble(Retrieval::getDistance)) + .sorted(Comparator.comparingDouble(Retrieval::getSimilarity)) .distinct() .collect(Collectors.toList()); Set metricIds = diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 316980d8a..6d46ce824 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -164,7 +164,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { List retrievals = result.matches().stream() .map(this::convertToRetrieval) - .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed()) + .sorted(Comparator.comparingDouble(Retrieval::getSimilarity)) .limit(num) .collect(Collectors.toList()); @@ -177,7 +177,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { private Retrieval convertToRetrieval(EmbeddingMatch embeddingMatch) { Retrieval retrieval = new Retrieval(); TextSegment embedded = embeddingMatch.embedded(); - retrieval.setDistance(1 - embeddingMatch.score()); + retrieval.setSimilarity(embeddingMatch.score()); retrieval.setId(TextSegmentConvert.getQueryId(embedded)); retrieval.setQuery(embedded.text()); diff --git a/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java b/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java index 6a04df504..f8dc386c0 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java @@ -12,7 +12,7 @@ public class Retrieval { protected String id; - protected double distance; + protected double similarity; protected String query; @@ -35,7 +35,7 @@ public class Retrieval { return false; } Retrieval retrieval = (Retrieval) o; - return Double.compare(retrieval.distance, distance) == 0 + return Double.compare(retrieval.similarity, similarity) == 0 && Objects.equal(id, retrieval.id) && Objects.equal(query, retrieval.query) && Objects.equal(metadata, retrieval.metadata); @@ -43,6 +43,6 @@ public class Retrieval { @Override public int hashCode() { - return Objects.hashCode(id, distance, query, metadata); + return Objects.hashCode(id, similarity, query, metadata); } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java index 2509961e8..a0431169c 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java @@ -12,8 +12,8 @@ import lombok.ToString; @AllArgsConstructor @NoArgsConstructor public class SchemaElementMatch { - SchemaElement element; + double offset; double similarity; String detectWord; String word; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java index 282a59063..569892af9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java @@ -12,9 +12,6 @@ import java.util.Map; public class EmbeddingResult extends MapResult { private String id; - - private double distance; - private Map metadata; @Override diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java index 9d8789a30..612dc3723 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java @@ -10,16 +10,13 @@ import java.util.List; @Data @ToString public class HanlpMapResult extends MapResult { - private List natures; - private int offset = 0; - private double similarity; - - public HanlpMapResult(String name, List natures, String detectWord) { + public HanlpMapResult(String name, List natures, String detectWord, double similarity) { this.name = name; this.natures = natures; this.detectWord = detectWord; + this.similarity = similarity; } @Override diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java index 616bfaf92..84ecf6254 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java @@ -10,16 +10,17 @@ import java.io.Serializable; public abstract class MapResult implements Serializable { protected String name; + protected int offset; - protected int index; protected String detectWord; + protected double similarity; + public abstract String getMapKey(); public Boolean lessSimilar(MapResult otherResult) { String mapKey = this.getMapKey(); String otherMapKey = otherResult.getMapKey(); - return mapKey.equals(otherMapKey) - && otherResult.getDetectWord().length() < otherResult.getDetectWord().length(); + return mapKey.equals(otherMapKey) && otherResult.similarity < otherResult.similarity; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java index c72dde3b5..9c688fc99 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java @@ -8,6 +8,7 @@ import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; +import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -62,7 +63,9 @@ public class SearchService { .map( entry -> { String name = entry.getKey().replace("#", " "); - return new HanlpMapResult(name, entry.getValue(), key); + double similarity = EditDistanceUtils.getSimilarity(name, key); + return new HanlpMapResult( + name, entry.getValue(), key, similarity); }) .sorted((a, b) -> -(b.getName().length() - a.getName().length())) .collect(Collectors.toList()); @@ -109,8 +112,10 @@ public class SearchService { .getType(), "")) .collect(Collectors.toList()); + name = StringUtils.reverse(name); - return new HanlpMapResult(name, natures, key); + double similarity = EditDistanceUtils.getSimilarity(name, key); + return new HanlpMapResult(name, natures, key, similarity); }) .sorted((a, b) -> -(b.getName().length() - a.getName().length())) .collect(Collectors.toList()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java index 28aaa12c3..ee73409c3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java @@ -270,7 +270,10 @@ public class HanlpHelper { if (orig != null) { MapResult addMapResult = new HanlpMapResult( - orig, Arrays.asList(nature), hanlpMapResult.getDetectWord()); + orig, + Arrays.asList(nature), + hanlpMapResult.getDetectWord(), + hanlpMapResult.getSimilarity()); mapResults.add((T) addMapResult); isAdd = true; } @@ -301,7 +304,7 @@ public class HanlpHelper { addMapResult.setDetectWord(embeddingResult.getDetectWord()); addMapResult.setId(embeddingResult.getId()); addMapResult.setMetadata(embeddingResult.getMetadata()); - addMapResult.setDistance(embeddingResult.getDistance()); + addMapResult.setSimilarity(embeddingResult.getSimilarity()); mapResults.add((T) addMapResult); isAdd = true; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java index 77e825017..cec1f32f8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java @@ -37,7 +37,7 @@ public abstract class BaseMapper implements SchemaMapper { try { doMap(chatQueryContext); - filter(chatQueryContext); + MapFilter.filter(chatQueryContext); } catch (Exception e) { log.error("work error", e); } @@ -147,33 +147,4 @@ public abstract class BaseMapper implements SchemaMapper { } return matches; } - - private void filter(ChatQueryContext chatQueryContext) { - MapFilter.filterByDataSetId(chatQueryContext); - MapFilter.filterByDetectWordLenLessThanOne(chatQueryContext); - switch (chatQueryContext.getQueryDataType()) { - case TAG: - MapFilter.filterByQueryDataType( - chatQueryContext, element -> !(element.getIsTag() > 0)); - break; - case METRIC: - MapFilter.filterByQueryDataType( - chatQueryContext, - element -> !SchemaElementType.METRIC.equals(element.getType())); - break; - case DIMENSION: - MapFilter.filterByQueryDataType( - chatQueryContext, - element -> { - boolean isDimensionOrValue = - SchemaElementType.DIMENSION.equals(element.getType()) - || SchemaElementType.VALUE.equals(element.getType()); - return !isDimensionOrValue; - }); - break; - case ALL: - default: - break; - } - } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index 52e540491..e83eed1a3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -49,7 +49,7 @@ public abstract class BaseMatchStrategy implements MatchStr boolean isDeleted = existResults.removeIf( existResult -> { - boolean delete = oneRoundResult.lessSimilar(existResult); + boolean delete = existResult.lessSimilar(oneRoundResult); if (delete) { log.info("deleted existResult:{}", existResult); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java index a68b49872..1218e7ffc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult; +import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; @@ -49,9 +50,8 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy results = new ArrayList<>(); for (Entry> entry : nameToItems.entrySet()) { String name = entry.getKey(); - if (!name.contains(detectSegment) - || mapperHelper.getSimilarity(detectSegment, name) - < metricDimensionThresholdConfig) { + double similarity = EditDistanceUtils.getSimilarity(detectSegment, name); + if (!name.contains(detectSegment) || similarity < metricDimensionThresholdConfig) { continue; } Set schemaElements = entry.getValue(); @@ -68,6 +68,7 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy Lists.partition(queryTextsList, embeddingMapperBatch); for (List queryTextsSub : queryTextsSubList) { - detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext); + List oneRoundResults = + detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext); + selectResultInOneRound(results, oneRoundResults); } return new ArrayList<>(results); } - private void detectByQueryTextsSub( - Set results, + private List detectByQueryTextsSub( Set detectDataSetIds, List queryTextsSub, ChatQueryContext chatQueryContext) { @@ -89,7 +90,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds); if (CollectionUtils.isEmpty(retrieveQueryResults)) { - return; + return new ArrayList<>(); } // step3. build EmbeddingResults List collect = @@ -103,8 +104,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy if (!retrieveQueryResult .getQuery() .contains(retrieval.getQuery())) { - return retrieval.getDistance() - > 1 - threshold; + return retrieval.getSimilarity() + < threshold; } return false; }); @@ -150,11 +151,9 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER)); int roundNumber = embeddingRoundNumber * queryTextsSub.size(); - List oneRoundResults = - collect.stream() - .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance)) - .limit(roundNumber) - .collect(Collectors.toList()); - selectResultInOneRound(results, oneRoundResults); + return collect.stream() + .sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity)) + .limit(roundNumber) + .collect(Collectors.toList()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java index 8c43199cf..94198faef 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java @@ -72,10 +72,15 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy hanlpMapResults.stream() .filter( term -> - mapperHelper.getSimilarity(detectSegment, term.getName()) + term.getSimilarity() >= getThresholdMatch( term.getNatures(), chatQueryContext)) .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) + .map( + parseResult -> { + parseResult.setOffset(offset); + return parseResult; + }) .collect(Collectors.toCollection(LinkedHashSet::new)); log.debug( @@ -83,18 +88,6 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy detectSegment, hanlpMapResults); - hanlpMapResults = - hanlpMapResults.stream() - .map( - parseResult -> { - parseResult.setOffset(offset); - parseResult.setSimilarity( - mapperHelper.getSimilarity( - detectSegment, parseResult.getName())); - return parseResult; - }) - .collect(Collectors.toCollection(LinkedHashSet::new)); - // step5. take only M dimensionValue or N-M metric/dimension value per rond. int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE)); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index 072d7d0e3..84d64a6f1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; +import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @@ -103,7 +104,6 @@ public class KeywordMapper extends BaseMapper { private void convertDatabaseMapResultToMapInfo( ChatQueryContext chatQueryContext, List mapResults) { - MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); for (DatabaseMapResult match : mapResults) { SchemaElement schemaElement = match.getSchemaElement(); Set regElementSet = @@ -118,7 +118,7 @@ public class KeywordMapper extends BaseMapper { .detectWord(match.getDetectWord()) .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) .similarity( - mapperHelper.getSimilarity( + EditDistanceUtils.getSimilarity( match.getDetectWord(), schemaElement.getName())) .build(); log.info("add to schema, elementMatch {}", schemaElementMatch); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java index a2020570f..23e986a6e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java @@ -7,14 +7,46 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Predicate; +import java.util.stream.Collectors; public class MapFilter { + public static void filter(ChatQueryContext chatQueryContext) { + filterByDataSetId(chatQueryContext); + filterByDetectWordLenLessThanOne(chatQueryContext); + switch (chatQueryContext.getQueryDataType()) { + case TAG: + filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0)); + break; + case METRIC: + filterByQueryDataType( + chatQueryContext, + element -> !SchemaElementType.METRIC.equals(element.getType())); + break; + case DIMENSION: + filterByQueryDataType( + chatQueryContext, + element -> { + boolean isDimensionOrValue = + SchemaElementType.DIMENSION.equals(element.getType()) + || SchemaElementType.VALUE.equals(element.getType()); + return !isDimensionOrValue; + }); + break; + case ALL: + default: + break; + } + filterByRules(chatQueryContext); + } + public static void filterByDataSetId(ChatQueryContext chatQueryContext) { Set dataSetIds = chatQueryContext.getDataSetIds(); if (CollectionUtils.isEmpty(dataSetIds)) { @@ -44,25 +76,93 @@ public class MapFilter { public static void filterByQueryDataType( ChatQueryContext chatQueryContext, Predicate needRemovePredicate) { - chatQueryContext - .getMapInfo() - .getDataSetElementMatches() - .values() - .forEach( - schemaElementMatches -> { - schemaElementMatches.removeIf( - schemaElementMatch -> { - SchemaElement element = schemaElementMatch.getElement(); - SchemaElementType type = element.getType(); + Map> dataSetElementMatches = + chatQueryContext.getMapInfo().getDataSetElementMatches(); + for (Map.Entry> entry : dataSetElementMatches.entrySet()) { + List schemaElementMatches = entry.getValue(); + schemaElementMatches.removeIf( + schemaElementMatch -> { + SchemaElement element = schemaElementMatch.getElement(); + SchemaElementType type = element.getType(); - boolean isEntityOrDatasetOrId = - SchemaElementType.ENTITY.equals(type) - || SchemaElementType.DATASET.equals(type) - || SchemaElementType.ID.equals(type); + boolean isEntityOrDatasetOrId = + SchemaElementType.ENTITY.equals(type) + || SchemaElementType.DATASET.equals(type) + || SchemaElementType.ID.equals(type); - return !isEntityOrDatasetOrId - && needRemovePredicate.test(element); - }); - }); + return !isEntityOrDatasetOrId && needRemovePredicate.test(element); + }); + } + } + + public static void filterByRules(ChatQueryContext chatQueryContext) { + Map> dataSetElementMatches = + chatQueryContext.getMapInfo().getDataSetElementMatches(); + for (Map.Entry> entry : dataSetElementMatches.entrySet()) { + List elementMatches = entry.getValue(); + filterByExactMatch(elementMatches); + filterInExactMatch(elementMatches); + } + } + + public static List filterByExactMatch(List matches) { + // Group by detectWord + Map> groupedByDetectWord = + matches.stream().collect(Collectors.groupingBy(SchemaElementMatch::getDetectWord)); + + List result = new ArrayList<>(); + + for (Map.Entry> entry : groupedByDetectWord.entrySet()) { + List group = entry.getValue(); + + // Filter out objects with similarity=1.0 + List fullMatches = + group.stream() + .filter(SchemaElementMatch::isFullMatched) + .collect(Collectors.toList()); + + if (!fullMatches.isEmpty()) { + // If there are objects with similarity=1.0, choose the one with the longest + // detectWord and smallest offset + SchemaElementMatch bestMatch = + fullMatches.stream() + .max( + Comparator.comparing( + (SchemaElementMatch match) -> + match.getDetectWord().length())) + .orElse(null); + if (bestMatch != null) { + result.add(bestMatch); + } + } else { + // If there are no objects with similarity=1.0, keep all objects with similarity<1.0 + result.addAll(group); + } + } + return result; + } + + public static List filterInExactMatch(List matches) { + Map> fullMatches = + matches.stream() + .filter(schemaElementMatch -> schemaElementMatch.isFullMatched()) + .collect(Collectors.groupingBy(SchemaElementMatch::getDetectWord)); + Set keys = new HashSet<>(fullMatches.keySet()); + for (String key1 : keys) { + for (String key2 : keys) { + if (!key1.equals(key2) && key1.contains(key2)) { + fullMatches.remove(key2); + } + } + } + List notFullMatches = + matches.stream() + .filter(schemaElementMatch -> !schemaElementMatch.isFullMatched()) + .collect(Collectors.toList()); + + List mergedMatches = new ArrayList<>(); + fullMatches.values().forEach(mergedMatches::addAll); + mergedMatches.addAll(notFullMatches); + return mergedMatches; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java index a30e6281c..5566bb00a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.chat.mapper; -import com.hankcs.hanlp.algorithm.EditDistance; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; import lombok.Data; @@ -67,19 +66,4 @@ public class MapperHelper { } return false; } - - /** - * * get similarity - * - * @param detectSegment - * @param matchName - * @return - */ - public double getSimilarity(String detectSegment, String matchName) { - String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase(); - String matchNameLower = matchName == null ? null : matchName.toLowerCase(); - return 1 - - (double) EditDistance.compute(detectSegmentLower, matchNameLower) - / Math.max(matchName.length(), detectSegment.length()); - } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java new file mode 100644 index 000000000..4ac8c360c --- /dev/null +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java @@ -0,0 +1,27 @@ +package com.tencent.supersonic.headless.chat.utils; + +import com.hankcs.hanlp.algorithm.EditDistance; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +@Data +@Service +@Slf4j +public class EditDistanceUtils { + + /** + * * get similarity + * + * @param detectSegment + * @param matchName + * @return + */ + public static double getSimilarity(String detectSegment, String matchName) { + String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase(); + String matchNameLower = matchName == null ? null : matchName.toLowerCase(); + return 1 + - (double) EditDistance.compute(detectSegmentLower, matchNameLower) + / Math.max(matchName.length(), detectSegment.length()); + } +}