diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java index ba5a0e9fd..aec5ddcea 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.knowledge; import com.google.common.base.Objects; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import lombok.Data; import lombok.ToString; @@ -27,4 +28,13 @@ public class DatabaseMapResult extends MapResult { public int hashCode() { return Objects.hashCode(name, schemaElement); } + + @Override + public String getMapKey() { + return this.getName() + + Constants.UNDERLINE + + this.getSchemaElement().getId() + + Constants.UNDERLINE + + this.getSchemaElement().getName(); + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java index 1eebf3d0c..282a59063 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.knowledge; import com.google.common.base.Objects; +import com.tencent.supersonic.common.pojo.Constants; import lombok.Data; import lombok.ToString; @@ -32,4 +33,9 @@ public class EmbeddingResult extends MapResult { public int hashCode() { return Objects.hashCode(id); } + + @Override + public String getMapKey() { + return this.getName() + Constants.UNDERLINE + this.getId(); + } } diff --git 
a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java index 767bcd141..9d8789a30 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.knowledge; import com.google.common.base.Objects; +import com.tencent.supersonic.common.pojo.Constants; import lombok.Data; import lombok.ToString; @@ -42,4 +43,11 @@ public class HanlpMapResult extends MapResult { public void setOffset(int offset) { this.offset = offset; } + + @Override + public String getMapKey() { + return this.getName() + + Constants.UNDERLINE + + String.join(Constants.UNDERLINE, this.getNatures()); + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java index ee7213e97..616bfaf92 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MapResult.java @@ -7,8 +7,19 @@ import java.io.Serializable; @Data @ToString -public class MapResult implements Serializable { +public abstract class MapResult implements Serializable { protected String name; + + protected int index; protected String detectWord; + + public abstract String getMapKey(); + + public Boolean lessSimilar(MapResult otherResult) { /* true when otherResult has the same map key as this result but a shorter detect word */ + String mapKey = this.getMapKey(); + String otherMapKey = otherResult.getMapKey(); + return mapKey.equals(otherMapKey) + && otherResult.getDetectWord().length() < this.getDetectWord().length(); + } } diff --git 
a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java index 4603cd796..28aaa12c3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java @@ -31,6 +31,8 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; /** HanLP helper */ @@ -87,7 +89,7 @@ public class HanlpHelper { return CustomDictionary; } - /** * reload custom dictionary */ + /** reload custom dictionary */ public static boolean reloadCustomDictionary() throws IOException { final long startTime = System.currentTimeMillis(); @@ -316,6 +318,28 @@ public class HanlpHelper { .collect(Collectors.toList()); } + public static List getTerms(List terms, Set dataSetIds) { + logTerms(terms); + if (!CollectionUtils.isEmpty(dataSetIds)) { + terms = + terms.stream() + .filter( + term -> { + Long dataSetId = + NatureHelper.getDataSetId( + term.getNature().toString()); + if (Objects.nonNull(dataSetId)) { + return dataSetIds.contains(dataSetId); + } + return false; + }) + .collect(Collectors.toList()); + log.debug("terms filter by dataSetId:{}", dataSetIds); + logTerms(terms); + } + return terms; + } + public static List transform2ApiTerm( Term term, Map> modelIdToDataSetIds) { List s2Terms = Lists.newArrayList(); @@ -331,4 +355,17 @@ public class HanlpHelper { } return s2Terms; } + + private static void logTerms(List terms) { + if (CollectionUtils.isEmpty(terms)) { + return; + } + for (S2Term term : terms) { + log.debug( + "word:{},nature:{},frequency:{}", + term.word, + term.nature.toString(), + term.getFrequency()); + } + } } diff --git 
a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java index a718756af..77e825017 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java @@ -6,20 +6,20 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; -import org.springframework.util.CollectionUtils; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; +import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Predicate; import java.util.stream.Collectors; @Slf4j @@ -50,85 +50,6 @@ public abstract class BaseMapper implements SchemaMapper { chatQueryContext.getMapInfo().getDataSetElementMatches()); } - private void filter(ChatQueryContext chatQueryContext) { - filterByDataSetId(chatQueryContext); - filterByDetectWordLenLessThanOne(chatQueryContext); - switch (chatQueryContext.getQueryDataType()) { - case TAG: - filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0)); - break; - case METRIC: - filterByQueryDataType( - chatQueryContext, - element -> !SchemaElementType.METRIC.equals(element.getType())); - break; 
- case DIMENSION: - filterByQueryDataType( - chatQueryContext, - element -> { - boolean isDimensionOrValue = - SchemaElementType.DIMENSION.equals(element.getType()) - || SchemaElementType.VALUE.equals(element.getType()); - return !isDimensionOrValue; - }); - break; - case ALL: - default: - break; - } - } - - private static void filterByDataSetId(ChatQueryContext chatQueryContext) { - Set dataSetIds = chatQueryContext.getDataSetIds(); - if (CollectionUtils.isEmpty(dataSetIds)) { - return; - } - Set dataSetIdInMapInfo = - new HashSet<>(chatQueryContext.getMapInfo().getDataSetElementMatches().keySet()); - for (Long dataSetId : dataSetIdInMapInfo) { - if (!dataSetIds.contains(dataSetId)) { - chatQueryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId); - } - } - } - - private static void filterByDetectWordLenLessThanOne(ChatQueryContext chatQueryContext) { - Map> dataSetElementMatches = - chatQueryContext.getMapInfo().getDataSetElementMatches(); - for (Map.Entry> entry : dataSetElementMatches.entrySet()) { - List value = entry.getValue(); - if (!CollectionUtils.isEmpty(value)) { - value.removeIf( - schemaElementMatch -> - StringUtils.length(schemaElementMatch.getDetectWord()) <= 1); - } - } - } - - private static void filterByQueryDataType( - ChatQueryContext chatQueryContext, Predicate needRemovePredicate) { - chatQueryContext - .getMapInfo() - .getDataSetElementMatches() - .values() - .forEach( - schemaElementMatches -> { - schemaElementMatches.removeIf( - schemaElementMatch -> { - SchemaElement element = schemaElementMatch.getElement(); - SchemaElementType type = element.getType(); - - boolean isEntityOrDatasetOrId = - SchemaElementType.ENTITY.equals(type) - || SchemaElementType.DATASET.equals(type) - || SchemaElementType.ID.equals(type); - - return !isEntityOrDatasetOrId - && needRemovePredicate.test(element); - }); - }); - } - public abstract void doMap(ChatQueryContext chatQueryContext); public void addToSchemaMap( @@ -202,4 +123,57 @@ public 
abstract class BaseMapper implements SchemaMapper { } return element.getAlias(); } + + public List getMatches( + ChatQueryContext chatQueryContext, BaseMatchStrategy matchStrategy) { + String queryText = chatQueryContext.getQueryText(); + List terms = + HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); + terms = HanlpHelper.getTerms(terms, chatQueryContext.getDataSetIds()); + Map> matchResult = + matchStrategy.match(chatQueryContext, terms, chatQueryContext.getDataSetIds()); + List matches = new ArrayList<>(); + if (Objects.isNull(matchResult)) { + return matches; + } + Optional> first = + matchResult.entrySet().stream() + .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) + .map(entry -> entry.getValue()) + .findFirst(); + + if (first.isPresent()) { + matches = first.get(); + } + return matches; + } + + private void filter(ChatQueryContext chatQueryContext) { + MapFilter.filterByDataSetId(chatQueryContext); + MapFilter.filterByDetectWordLenLessThanOne(chatQueryContext); + switch (chatQueryContext.getQueryDataType()) { + case TAG: + MapFilter.filterByQueryDataType( + chatQueryContext, element -> !(element.getIsTag() > 0)); + break; + case METRIC: + MapFilter.filterByQueryDataType( + chatQueryContext, + element -> !SchemaElementType.METRIC.equals(element.getType())); + break; + case DIMENSION: + MapFilter.filterByQueryDataType( + chatQueryContext, + element -> { + boolean isDimensionOrValue = + SchemaElementType.DIMENSION.equals(element.getType()) + || SchemaElementType.VALUE.equals(element.getType()); + return !isDimensionOrValue; + }); + break; + case ALL: + default: + break; + } + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index d6cb1abd1..52e540491 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ 
b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -3,32 +3,21 @@ package com.tencent.supersonic.headless.chat.mapper; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; -import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; +import com.tencent.supersonic.headless.chat.knowledge.MapResult; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; -import java.util.ArrayList; -import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; @Service @Slf4j -public abstract class BaseMatchStrategy implements MatchStrategy { - - @Autowired protected MapperHelper mapperHelper; - - @Autowired protected MapperConfig mapperConfig; - +public abstract class BaseMatchStrategy implements MatchStrategy { @Override public Map> match( ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { @@ -48,37 +37,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { public List detect( ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - Map regOffsetToLength = getRegOffsetToLength(terms); - String text = chatQueryContext.getQueryText(); - Set results = new HashSet<>(); - - Set detectSegments = new HashSet<>(); - - for (Integer startIndex = 0; startIndex <= text.length() - 1; ) { - - for (Integer index = startIndex; index <= text.length(); ) { - int offset = mapperHelper.getStepOffset(terms, startIndex); - index = mapperHelper.getStepIndex(regOffsetToLength, index); - if 
(index <= text.length()) { - String detectSegment = text.substring(startIndex, index).trim(); - detectSegments.add(detectSegment); - detectByStep( - chatQueryContext, results, detectDataSetIds, detectSegment, offset); - } - } - startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); - } - return new ArrayList<>(results); - } - - public Map getRegOffsetToLength(List terms) { - return terms.stream() - .sorted(Comparator.comparing(S2Term::length)) - .collect( - Collectors.toMap( - S2Term::getOffset, - term -> term.word.length(), - (value1, value2) -> value2)); + throw new RuntimeException("Not implemented"); } public void selectResultInOneRound(Set existResults, List oneRoundResults) { @@ -90,7 +49,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { boolean isDeleted = existResults.removeIf( existResult -> { - boolean delete = needDelete(oneRoundResult, existResult); + boolean delete = oneRoundResult.lessSimilar(existResult); if (delete) { log.info("deleted existResult:{}", existResult); } @@ -106,72 +65,6 @@ public abstract class BaseMatchStrategy implements MatchStrategy { } } - public List getMatches(ChatQueryContext chatQueryContext, List terms) { - Set dataSetIds = chatQueryContext.getDataSetIds(); - terms = filterByDataSetId(terms, dataSetIds); - Map> matchResult = match(chatQueryContext, terms, dataSetIds); - List matches = new ArrayList<>(); - if (Objects.isNull(matchResult)) { - return matches; - } - Optional> first = - matchResult.entrySet().stream() - .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) - .map(entry -> entry.getValue()) - .findFirst(); - - if (first.isPresent()) { - matches = first.get(); - } - return matches; - } - - public List filterByDataSetId(List terms, Set dataSetIds) { - logTerms(terms); - if (CollectionUtils.isNotEmpty(dataSetIds)) { - terms = - terms.stream() - .filter( - term -> { - Long dataSetId = - NatureHelper.getDataSetId( - term.getNature().toString()); - if 
(Objects.nonNull(dataSetId)) { - return dataSetIds.contains(dataSetId); - } - return false; - }) - .collect(Collectors.toList()); - log.debug("terms filter by dataSetId:{}", dataSetIds); - logTerms(terms); - } - return terms; - } - - public void logTerms(List terms) { - if (CollectionUtils.isEmpty(terms)) { - return; - } - for (S2Term term : terms) { - log.debug( - "word:{},nature:{},frequency:{}", - term.word, - term.nature.toString(), - term.getFrequency()); - } - } - - public abstract boolean needDelete(T oneRoundResult, T existResult); - - public abstract String getMapKey(T a); - - public abstract void detectByStep( - ChatQueryContext chatQueryContext, - Set existResults, - Set detectDataSetIds, - String detectSegment, - int offset); - public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { double decreaseAmount = (threshold - minThreshold) / 4; double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java new file mode 100644 index 000000000..8c4f020e3 --- /dev/null +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java @@ -0,0 +1,47 @@ +package com.tencent.supersonic.headless.chat.mapper; + +import com.tencent.supersonic.headless.api.pojo.response.S2Term; +import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.knowledge.MapResult; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +@Service +@Slf4j +public abstract class BatchMatchStrategy extends BaseMatchStrategy { + + @Autowired protected MapperConfig mapperConfig; + + @Override + 
public List detect( + ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { + + String text = chatQueryContext.getQueryText(); + Set detectSegments = new HashSet<>(); + + int embeddingTextSize = + Integer.valueOf( + mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE)); + + int embeddingTextStep = + Integer.valueOf( + mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP)); + + for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) { + int endIndex = Math.min(startIndex + embeddingTextSize, text.length()); + String detectSegment = text.substring(startIndex, endIndex).trim(); + detectSegments.add(detectSegment); + } + return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments); + } + + public abstract List detectByBatch( + ChatQueryContext chatQueryContext, + Set detectDataSetIds, + Set detectSegments); +} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java index f9dd9d6c1..a68b49872 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.chat.mapper; -import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.response.S2Term; @@ -25,7 +24,7 @@ import java.util.stream.Collectors; */ @Service @Slf4j -public class DatabaseMatchStrategy extends BaseMatchStrategy { +public class DatabaseMatchStrategy extends SingleMatchStrategy { private List allElements; @@ -36,34 +35,18 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy 
return super.match(chatQueryContext, terms, detectDataSetIds); } - @Override - public boolean needDelete(DatabaseMapResult oneRoundResult, DatabaseMapResult existResult) { - return getMapKey(oneRoundResult).equals(getMapKey(existResult)) - && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length(); - } - - @Override - public String getMapKey(DatabaseMapResult a) { - return a.getName() - + Constants.UNDERLINE - + a.getSchemaElement().getId() - + Constants.UNDERLINE - + a.getSchemaElement().getName(); - } - - public void detectByStep( + public List detectByStep( ChatQueryContext chatQueryContext, - Set existResults, Set detectDataSetIds, String detectSegment, int offset) { if (StringUtils.isBlank(detectSegment)) { - return; + return new ArrayList<>(); } Double metricDimensionThresholdConfig = getThreshold(chatQueryContext); Map> nameToItems = getNameToItems(allElements); - + List results = new ArrayList<>(); for (Entry> entry : nameToItems.entrySet()) { String name = entry.getKey(); if (!name.contains(detectSegment) @@ -86,9 +69,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy databaseMapResult.setDetectWord(detectSegment); databaseMapResult.setName(schemaElement.getName()); databaseMapResult.setSchemaElement(schemaElement); - existResults.add(databaseMapResult); + results.add(databaseMapResult); } } + return results; } private List getSchemaElements(ChatQueryContext chatQueryContext) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java index 44adfc396..5bea95848 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java @@ -4,7 +4,6 @@ import com.tencent.supersonic.common.util.ContextUtils; import 
com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; -import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder; @@ -15,19 +14,15 @@ import lombok.extern.slf4j.Slf4j; import java.util.List; import java.util.Objects; -/** * A mapper that recognizes schema elements with vector embedding. */ +/** A mapper that recognizes schema elements with vector embedding. */ @Slf4j public class EmbeddingMapper extends BaseMapper { @Override public void doMap(ChatQueryContext chatQueryContext) { // 1. query from embedding by queryText - String queryText = chatQueryContext.getQueryText(); - List terms = - HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); - EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); - List matchResults = matchStrategy.getMatches(chatQueryContext, terms); + List matchResults = getMatches(chatQueryContext, matchStrategy); HanlpHelper.transLetterOriginal(matchResults); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index 8ebffa519..680a152be 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -1,8 +1,6 @@ package com.tencent.supersonic.headless.chat.mapper; import com.google.common.collect.Lists; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.headless.api.pojo.response.S2Term; 
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; @@ -35,54 +33,12 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING */ @Service @Slf4j -public class EmbeddingMatchStrategy extends BaseMatchStrategy { +public class EmbeddingMatchStrategy extends BatchMatchStrategy { @Autowired private MetaEmbeddingService metaEmbeddingService; @Override - public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) { - return getMapKey(oneRoundResult).equals(getMapKey(existResult)) - && existResult.getDistance() > oneRoundResult.getDistance(); - } - - @Override - public String getMapKey(EmbeddingResult a) { - return a.getName() + Constants.UNDERLINE + a.getId(); - } - - @Override - public void detectByStep( - ChatQueryContext chatQueryContext, - Set existResults, - Set detectDataSetIds, - String detectSegment, - int offset) {} - - @Override - public List detect( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - String text = chatQueryContext.getQueryText(); - Set detectSegments = new HashSet<>(); - - int embeddingTextSize = - Integer.valueOf( - mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE)); - - int embeddingTextStep = - Integer.valueOf( - mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP)); - - for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) { - int endIndex = Math.min(startIndex + embeddingTextSize, text.length()); - String detectSegment = text.substring(startIndex, endIndex).trim(); - detectSegments.add(detectSegment); - } - Set results = - detectByBatch(chatQueryContext, detectDataSetIds, detectSegments); - return new ArrayList<>(results); - } - - protected Set detectByBatch( + public List detectByBatch( ChatQueryContext chatQueryContext, Set 
detectDataSetIds, Set detectSegments) { @@ -103,7 +59,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { for (List queryTextsSub : queryTextsSubList) { detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext); } - return results; + return new ArrayList<>(results); } private void detectByQueryTextsSub( diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java index 11cb16383..8c43199cf 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java @@ -1,22 +1,16 @@ package com.tencent.supersonic.headless.chat.mapper; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.ArrayList; -import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; -import java.util.Map; -import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -30,36 +24,12 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DI */ @Service @Slf4j -public class HanlpDictMatchStrategy extends BaseMatchStrategy { +public class HanlpDictMatchStrategy extends SingleMatchStrategy { @Autowired private KnowledgeBaseService knowledgeBaseService; - @Override - 
public Map> match( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - String text = chatQueryContext.getQueryText(); - if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { - return null; - } - - log.debug("terms:{},detectModelIds:{}", terms, detectDataSetIds); - - List detects = detect(chatQueryContext, terms, detectDataSetIds); - Map> result = new HashMap<>(); - - result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); - return result; - } - - @Override - public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) { - return getMapKey(oneRoundResult).equals(getMapKey(existResult)) - && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length(); - } - - public void detectByStep( + public List detectByStep( ChatQueryContext chatQueryContext, - Set existResults, Set detectDataSetIds, String detectSegment, int offset) { @@ -89,7 +59,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { hanlpMapResults.addAll(suffixHanlpMapResults); if (CollectionUtils.isEmpty(hanlpMapResults)) { - return; + return new ArrayList<>(); } // step3. merge pre/suffix result hanlpMapResults = @@ -155,12 +125,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { .collect(Collectors.toList()); oneRoundResults.addAll(additionalResults); } - // step6. 
select mapResul in one round - selectResultInOneRound(existResults, oneRoundResults); - } - - public String getMapKey(HanlpMapResult a) { - return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures()); + return oneRoundResults; } public double getThresholdMatch(List natures, ChatQueryContext chatQueryContext) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index df0bb74d4..072d7d0e3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -38,16 +38,15 @@ public class KeywordMapper extends BaseMapper { HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class); - List hanlpMapResults = - hanlpMatchStrategy.getMatches(chatQueryContext, terms); - convertHanlpMapResultToMapInfo(hanlpMapResults, chatQueryContext, terms); + List matchResults = getMatches(chatQueryContext, hanlpMatchStrategy); + + convertHanlpMapResultToMapInfo(matchResults, chatQueryContext, terms); // 2.database Match DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class); - List databaseResults = - databaseMatchStrategy.getMatches(chatQueryContext, terms); + getMatches(chatQueryContext, databaseMatchStrategy); convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java new file mode 100644 index 000000000..a2020570f --- /dev/null +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java @@ -0,0 +1,68 @@ +package com.tencent.supersonic.headless.chat.mapper; + +import 
com.tencent.supersonic.headless.api.pojo.SchemaElement;
+import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
+import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
+import com.tencent.supersonic.headless.chat.ChatQueryContext;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.util.CollectionUtils;
+
+import java.util.List;
+import java.util.Set;
+import java.util.function.Predicate;
+
+/** Static filters that prune the element matches stored in a query context's map info. */
+public class MapFilter {
+
+    /**
+     * Drops the matches of every data set whose id is not listed in the query context.
+     *
+     * @param chatQueryContext context whose map info is pruned in place
+     */
+    public static void filterByDataSetId(ChatQueryContext chatQueryContext) {
+        Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
+        if (CollectionUtils.isEmpty(dataSetIds)) {
+            return; // no data set ids listed: every match is kept
+        }
+        chatQueryContext.getMapInfo().getDataSetElementMatches().keySet().retainAll(dataSetIds);
+    }
+
+    /**
+     * Removes every match whose detect word is at most one character long; a {@code null} detect
+     * word counts as length zero.
+     *
+     * @param chatQueryContext context whose map info is pruned in place
+     */
+    public static void filterByDetectWordLenLessThanOne(ChatQueryContext chatQueryContext) {
+        for (List<SchemaElementMatch> matches :
+                chatQueryContext.getMapInfo().getDataSetElementMatches().values()) {
+            if (!CollectionUtils.isEmpty(matches)) {
+                matches.removeIf(match -> StringUtils.length(match.getDetectWord()) <= 1);
+            }
+        }
+    }
+
+    /**
+     * Removes the matches whose element satisfies {@code needRemovePredicate}; elements of type
+     * ENTITY, DATASET or ID are always kept.
+     *
+     * @param chatQueryContext context whose map info is pruned in place
+     * @param needRemovePredicate returns {@code true} for elements that should be removed
+     */
+    public static void filterByQueryDataType(
+            ChatQueryContext chatQueryContext, Predicate<SchemaElement> needRemovePredicate) {
+        for (List<SchemaElementMatch> matches :
+                chatQueryContext.getMapInfo().getDataSetElementMatches().values()) {
+            matches.removeIf(
+                    match -> {
+                        SchemaElement element = match.getElement();
+                        SchemaElementType type = element.getType();
+                        boolean alwaysKept =
+                                SchemaElementType.ENTITY.equals(type)
+                                        || SchemaElementType.DATASET.equals(type)
+                                        || SchemaElementType.ID.equals(type);
+                        return !alwaysKept && needRemovePredicate.test(element);
+                    });
+        }
+    }
+}
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java index fe6373e12..a30e6281c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java @@ -43,6 +43,16 @@ public class MapperHelper { return index; } + public Map getRegOffsetToLength(List terms) { + return terms.stream() + .sorted(Comparator.comparing(S2Term::length)) + .collect( + Collectors.toMap( + S2Term::getOffset, + term -> term.word.length(), + (value1, value2) -> value2)); + } + /** * * exist dimension values * @@ -58,15 +68,6 @@ public class MapperHelper { return false; } - public boolean existTerms(List natures) { - for (String nature : natures) { - if (NatureHelper.isTermNature(nature)) { - return true; - } - } - return false; - } - /** * * get similarity * diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java index af0bf9ed7..049b138f8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.knowledge.MapResult; import java.util.List; import java.util.Map; @@ -10,7 +11,7 @@ import java.util.Set; /** * MatchStrategy encapsulates a concrete matching algorithm executed during query or search process.
*/ -public interface MatchStrategy { +public interface MatchStrategy { Map> match( ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java index 9a67ad75d..4afd02b9c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java @@ -29,11 +29,13 @@ public class SearchMatchStrategy extends BaseMatchStrategy { @Autowired private KnowledgeBaseService knowledgeBaseService; + @Autowired private MapperHelper mapperHelper; + @Override public Map> match( ChatQueryContext chatQueryContext, List originals, Set detectDataSetIds) { String text = chatQueryContext.getQueryText(); - Map regOffsetToLength = getRegOffsetToLength(originals); + Map regOffsetToLength = mapperHelper.getRegOffsetToLength(originals); List detectIndexList = Lists.newArrayList(); @@ -104,22 +106,4 @@ public class SearchMatchStrategy extends BaseMatchStrategy { }); return regTextMap; } - - @Override - public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) { - return false; - } - - @Override - public String getMapKey(HanlpMapResult a) { - return null; - } - - @Override - public void detectByStep( - ChatQueryContext chatQueryContext, - Set existResults, - Set detectDataSetIds, - String detectSegment, - int offset) {} }
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java
new file mode 100644
index 000000000..6b1b69b5e
--- /dev/null
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java
@@ -0,0 +1,80 @@
+package com.tencent.supersonic.headless.chat.mapper;
+
+import com.tencent.supersonic.headless.api.pojo.response.S2Term;
+import com.tencent.supersonic.headless.chat.ChatQueryContext;
+import com.tencent.supersonic.headless.chat.knowledge.MapResult;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Match strategy that scans the query text segment by segment and delegates the lookup for each
+ * segment to {@link #detectByStep}.
+ *
+ * @param <T> concrete {@link MapResult} type produced by the strategy
+ */
+@Service
+@Slf4j
+public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatchStrategy<T> {
+    @Autowired protected MapperConfig mapperConfig;
+    @Autowired protected MapperHelper mapperHelper;
+
+    /**
+     * Enumerates every segment of the query text that starts and ends on a step boundary supplied
+     * by {@link MapperHelper#getStepIndex}, runs {@link #detectByStep} on the trimmed segment and
+     * folds each round into the accumulated result set via {@code selectResultInOneRound}.
+     *
+     * @param chatQueryContext context that supplies the query text
+     * @param terms terms used to derive the step boundaries and the per-segment offset
+     * @param detectDataSetIds data set ids forwarded unchanged to {@link #detectByStep}
+     * @return the accumulated results as a new list
+     */
+    public List<T> detect(
+            ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
+        Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
+        String text = chatQueryContext.getQueryText();
+        Set<T> results = new HashSet<>();
+
+        int startIndex = 0;
+        while (startIndex < text.length()) {
+            // The offset depends only on the segment start, so it is looked up once per start
+            // position instead of once per end position.
+            // NOTE(review): assumes MapperHelper#getStepOffset has no side effects - confirm.
+            int offset = mapperHelper.getStepOffset(terms, startIndex);
+            int index = startIndex;
+            while (index <= text.length()) {
+                index = mapperHelper.getStepIndex(regOffsetToLength, index);
+                // A step index beyond the end of the text yields no segment and ends the scan.
+                if (index <= text.length()) {
+                    String detectSegment = text.substring(startIndex, index).trim();
+                    List<T> oneRoundResults =
+                            detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
+                    selectResultInOneRound(results, oneRoundResults);
+                }
+            }
+            startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
+        }
+        return new ArrayList<>(results);
+    }
+
+    /**
+     * Detects the results for a single segment of the query text.
+     *
+     * @param chatQueryContext context of the query being mapped
+     * @param detectDataSetIds data set ids passed through from {@link #detect}
+     * @param detectSegment trimmed segment of the query text to look up
+     * @param offset value of {@link MapperHelper#getStepOffset} for the segment's start index
+     * @return results found for this segment
+     */
+    public abstract List<T> detectByStep(
+            ChatQueryContext chatQueryContext,
+            Set<Long> detectDataSetIds,
+            String detectSegment,
+            int offset);
+}