From e00b935c1f6366f47f04c25082b2f876c81e5f8e Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 6 Nov 2023 17:43:46 +0800 Subject: [PATCH] (improvement)(chat) Perform a complete refactoring of the mapper and integrate the fuzzyNameMapper into the BaseMatchStrategy (#328) --- .../supersonic/chat/mapper/BaseMapper.java | 33 ---- .../chat/mapper/BaseMatchStrategy.java | 51 ++++-- .../chat/mapper/EmbeddingMapper.java | 8 +- .../chat/mapper/EmbeddingMatchStrategy.java | 4 +- .../chat/mapper/FuzzyMatchStrategy.java | 132 +++++++++++++++ .../chat/mapper/FuzzyNameMapper.java | 155 +++--------------- .../chat/mapper/HanlpDictMapper.java | 7 +- .../chat/mapper/HanlpMatchStrategy.java | 10 +- .../supersonic/chat/mapper/MatchStrategy.java | 4 +- .../chat/mapper/SearchMatchStrategy.java | 6 +- .../chat/service/impl/SearchServiceImpl.java | 5 +- .../knowledge/dictionary/FuzzyResult.java | 30 ++++ 12 files changed, 244 insertions(+), 201 deletions(-) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyMatchStrategy.java create mode 100644 chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/FuzzyResult.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java index 48440a70a..e09e325e0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.chat.mapper; -import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.chat.api.component.SchemaMapper; import com.tencent.supersonic.chat.api.pojo.ModelSchema; import com.tencent.supersonic.chat.api.pojo.QueryContext; @@ -10,15 +9,12 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.knowledge.utils.NatureHelper; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; @@ -57,35 +53,6 @@ public abstract class BaseMapper implements SchemaMapper { schemaElementMatches.add(schemaElementMatch); } - public Set getModelIds(QueryContext queryContext) { - return ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest()); - } - - public List filterByModelIds(List terms, Set detectModelIds) { - logTerms(terms); - if (CollectionUtils.isNotEmpty(detectModelIds)) { - terms = terms.stream().filter(term -> { - Long modelId = NatureHelper.getModelId(term.getNature().toString()); - if (Objects.nonNull(modelId)) { - return detectModelIds.contains(modelId); - } - return false; - }).collect(Collectors.toList()); - log.info("terms filter by modelIds:{}", detectModelIds); - logTerms(terms); - } - return terms; - } - - public void logTerms(List terms) { - if (CollectionUtils.isEmpty(terms)) { - return; - } - for (Term term : terms) { - log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency()); - } - } - public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) { SchemaElement element = new SchemaElement(); SemanticService schemaService = ContextUtils.getBean(SemanticService.class); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java index e2265c01b..693c29846 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java @@ -1,7 +1,8 @@ package com.tencent.supersonic.chat.mapper; import com.hankcs.hanlp.seg.common.Term; -import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.knowledge.utils.NatureHelper; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -28,26 +29,25 @@ public abstract class BaseMatchStrategy implements MatchStrategy { @Autowired private MapperHelper mapperHelper; - @Override - public Map> match(QueryReq queryReq, List terms, Set detectModelIds) { - String text = queryReq.getQueryText(); + public Map> match(QueryContext queryContext, List terms, Set detectModelIds) { + String text = queryContext.getRequest().getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; } log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds); - List detects = detect(queryReq, terms, detectModelIds); + List detects = detect(queryContext, terms, detectModelIds); Map> result = new HashMap<>(); result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); return result; } - public List detect(QueryReq queryReq, List terms, Set detectModelIds) { + public List detect(QueryContext queryContext, List terms, Set detectModelIds) { Map regOffsetToLength = getRegOffsetToLength(terms); - String text = queryReq.getQueryText(); + String text = queryContext.getRequest().getQueryText(); Set results = new HashSet<>(); for (Integer index = 0; index <= text.length() - 1; ) { @@ -56,7 +56,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { int offset = mapperHelper.getStepOffset(terms, index); i = mapperHelper.getStepIndex(regOffsetToLength, i); if (i <= text.length()) { - detectByStep(queryReq, results, detectModelIds, index, i, offset); + detectByStep(queryContext, results, detectModelIds, index, i, offset); } } index = mapperHelper.getStepIndex(regOffsetToLength, index); @@ -94,8 +94,10 @@ public abstract class BaseMatchStrategy implements MatchStrategy { } } - public List getMatches(QueryReq queryReq, List terms, Set detectModelIds) { - Map> matchResult = match(queryReq, terms, detectModelIds); + public List getMatches(QueryContext queryContext, List terms) { + Set detectModelIds = mapperHelper.getModelIds(queryContext.getRequest()); + terms = filterByModelIds(terms, detectModelIds); + Map> matchResult = match(queryContext, terms, detectModelIds); List matches = new ArrayList<>(); if (Objects.isNull(matchResult)) { return matches; @@ -110,12 +112,37 @@ public abstract class BaseMatchStrategy implements MatchStrategy { return matches; } + public List filterByModelIds(List terms, Set detectModelIds) { + logTerms(terms); + if (CollectionUtils.isNotEmpty(detectModelIds)) { + terms = terms.stream().filter(term -> { + Long modelId = NatureHelper.getModelId(term.getNature().toString()); + if (Objects.nonNull(modelId)) { + return detectModelIds.contains(modelId); + } + return false; + }).collect(Collectors.toList()); + log.info("terms filter by modelIds:{}", detectModelIds); + logTerms(terms); + } + return terms; + } + + public void logTerms(List terms) { + if (CollectionUtils.isEmpty(terms)) { + return; + } + for (Term term : terms) { + log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency()); + } + } + public abstract boolean needDelete(T oneRoundResult, T existResult); public abstract String getMapKey(T a); - public abstract void detectByStep(QueryReq queryReq, Set results, Set detectModelIds, Integer startIndex, - Integer index, int offset); + public abstract void detectByStep(QueryContext queryContext, Set results, + Set detectModelIds, Integer startIndex, Integer index, int offset); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java index 3d1d02458..e95bc8db8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java @@ -11,7 +11,6 @@ import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult; import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder; import com.tencent.supersonic.knowledge.utils.HanlpHelper; import java.util.List; -import java.util.Set; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @@ -29,12 +28,7 @@ public class EmbeddingMapper extends BaseMapper { List terms = HanlpHelper.getTerms(queryText); EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); - - Set detectModelIds = getModelIds(queryContext); - - terms = filterByModelIds(terms, detectModelIds); - - List matchResults = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds); + List matchResults = matchStrategy.getMatches(queryContext, terms); HanlpHelper.transLetterOriginal(matchResults); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java index 280beef89..5e2df84f0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.mapper; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.common.pojo.Constants; @@ -46,8 +47,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { return a.getName() + Constants.UNDERLINE + a.getId(); } - public void detectByStep(QueryReq queryReq, Set existResults, Set detectModelIds, + public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, Integer startIndex, Integer index, int offset) { + QueryReq queryReq = queryContext.getRequest(); String detectSegment = queryReq.getQueryText().substring(startIndex, index); // step1. build query params if (StringUtils.isBlank(detectSegment) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyMatchStrategy.java new file mode 100644 index 000000000..80f3c9637 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyMatchStrategy.java @@ -0,0 +1,132 @@ +package com.tencent.supersonic.chat.mapper; + +import com.hankcs.hanlp.seg.common.Term; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.config.OptimizationConfig; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.knowledge.dictionary.FuzzyResult; +import com.tencent.supersonic.knowledge.service.SchemaService; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; + +/** + * Fuzzy Match Strategy + */ +@Service +@Slf4j +public class FuzzyMatchStrategy extends BaseMatchStrategy { + + @Autowired + private OptimizationConfig optimizationConfig; + @Autowired + private MapperHelper mapperHelper; + @Autowired + private SchemaService schemaService; + private List allElements; + + + @Override + public Map> match(QueryContext queryContext, List terms, + Set detectModelIds) { + this.allElements = getSchemaElements(); + return super.match(queryContext, terms, detectModelIds); + } + + @Override + public boolean needDelete(FuzzyResult oneRoundResult, FuzzyResult existResult) { + return getMapKey(oneRoundResult).equals(getMapKey(existResult)) + && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length(); + } + + @Override + public String getMapKey(FuzzyResult a) { + return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId() + + Constants.UNDERLINE + a.getSchemaElement().getName(); + } + + public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, + Integer startIndex, Integer index, int offset) { + String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index); + // step1. build query params + if (StringUtils.isBlank(detectSegment)) { + return; + } + Set modelIds = mapperHelper.getModelIds(queryContext.getRequest()); + + Double metricDimensionThresholdConfig = getThreshold(queryContext); + + Map> nameToItems = getNameToItems(allElements); + + for (Entry> entry : nameToItems.entrySet()) { + String name = entry.getKey(); + if (!name.contains(detectSegment) + || mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) { + continue; + } + Set schemaElements = entry.getValue(); + if (!CollectionUtils.isEmpty(modelIds)) { + schemaElements = schemaElements.stream() + .filter(schemaElement -> modelIds.contains(schemaElement.getModel())) + .collect(Collectors.toSet()); + } + for (SchemaElement schemaElement : schemaElements) { + FuzzyResult fuzzyResult = new FuzzyResult(); + fuzzyResult.setDetectWord(detectSegment); + fuzzyResult.setName(schemaElement.getName()); + fuzzyResult.setSchemaElement(schemaElement); + existResults.add(fuzzyResult); + } + } + } + + private List getSchemaElements() { + List allElements = new ArrayList<>(); + allElements.addAll(schemaService.getSemanticSchema().getDimensions()); + allElements.addAll(schemaService.getSemanticSchema().getMetrics()); + return allElements; + } + + + private Double getThreshold(QueryContext queryContext) { + Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig(); + Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig(); + + Map> modelElementMatches = queryContext.getMapInfo().getModelElementMatches(); + + boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1); + + if (!existElement) { + double halfThreshold = metricDimensionThresholdConfig / 2; + + metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold + : metricDimensionMinThresholdConfig; + log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}", + modelElementMatches, metricDimensionThresholdConfig); + } + return metricDimensionThresholdConfig; + } + + private Map> getNameToItems(List models) { + return models.stream().collect( + Collectors.toMap(SchemaElement::getName, a -> { + Set result = new HashSet<>(); + result.add(a); + return result; + }, (k1, k2) -> { + k1.addAll(k2); + return k1; + })); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java index eb0548993..a6fb1e85c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java @@ -6,18 +6,11 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; -import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.knowledge.dictionary.FuzzyResult; import com.tencent.supersonic.knowledge.utils.HanlpHelper; -import java.util.Comparator; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; @@ -34,145 +27,39 @@ public class FuzzyNameMapper extends BaseMapper { List terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText()); - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - - detectAndAddToSchema(queryContext, terms, semanticSchema.getDimensions(), SchemaElementType.DIMENSION); - - detectAndAddToSchema(queryContext, terms, semanticSchema.getMetrics(), SchemaElementType.METRIC); - - } - - private void detectAndAddToSchema(QueryContext queryContext, List terms, List models, - SchemaElementType schemaElementType) { - try { - - Map> modelResultSet = getResultSet(queryContext, terms, models); - - addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType); - - } catch (Exception e) { - log.error("detectAndAddToSchema error", e); - } - } - - private Map> getResultSet(QueryContext queryContext, List terms, - List models) { - - String queryText = queryContext.getRequest().getQueryText(); + FuzzyMatchStrategy fuzzyMatchStrategy = ContextUtils.getBean(FuzzyMatchStrategy.class); MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); - Set modelIds = mapperHelper.getModelIds(queryContext.getRequest()); - Double metricDimensionThresholdConfig = getThreshold(queryContext); + List matches = fuzzyMatchStrategy.getMatches(queryContext, terms); - Map> nameToItems = getNameToItems(models); - - Map regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length)) - .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2)); - - Map> modelResultSet = new HashMap<>(); - for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) { - for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) { - endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex); - if (endIndex > queryText.length()) { - continue; - } - String detectSegment = queryText.substring(startIndex, endIndex); - - for (Entry> entry : nameToItems.entrySet()) { - String name = entry.getKey(); - Set schemaElements = entry.getValue(); - if (!name.contains(detectSegment) - || mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) { - continue; - } - if (!CollectionUtils.isEmpty(modelIds)) { - schemaElements = schemaElements.stream() - .filter(schemaElement -> modelIds.contains(schemaElement.getModel())) - .collect(Collectors.toSet()); - } - Set preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements); - if (Objects.nonNull(preSchemaElements)) { - preSchemaElements.addAll(schemaElements); - } - } - } - startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); - } - return modelResultSet; - } - - private Double getThreshold(QueryContext queryContext) { - - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig(); - Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig(); - - Map> modelElementMatches = queryContext.getMapInfo() - .getModelElementMatches(); - boolean existElement = modelElementMatches.entrySet().stream() - .anyMatch(entry -> entry.getValue().size() >= 1); - - if (!existElement) { - double halfThreshold = metricDimensionThresholdConfig / 2; - - metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold - : metricDimensionMinThresholdConfig; - log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}", - modelElementMatches, metricDimensionThresholdConfig); - } - return metricDimensionThresholdConfig; - } - - private Map> getNameToItems(List models) { - return models.stream().collect( - Collectors.toMap(SchemaElement::getName, a -> { - Set result = new HashSet<>(); - result.add(a); - return result; - }, (k1, k2) -> { - k1.addAll(k2); - return k1; - })); - } - - private void addToSchemaMapInfo(Map> mapResultRowSet, SchemaMapInfo schemaMap, - SchemaElementType schemaElementType) { - if (Objects.isNull(mapResultRowSet) || mapResultRowSet.size() <= 0) { - return; - } - MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); - - for (Map.Entry> entry : mapResultRowSet.entrySet()) { - String detectWord = entry.getKey(); - Set schemaElements = entry.getValue(); - for (SchemaElement schemaElement : schemaElements) { - - Set regElementSet = getRegElementSet(schemaMap, schemaElementType, schemaElement); - if (regElementSet.contains(schemaElement.getId())) { - continue; - } - SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() - .element(schemaElement) - .word(schemaElement.getName()) - .detectWord(detectWord) - .frequency(10000L) - .similarity(mapperHelper.getSimilarity(detectWord, schemaElement.getName())) - .build(); - log.info("schemaElementType:{},add to schema, elementMatch {}", schemaElementType, schemaElementMatch); - addToSchemaMap(schemaMap, schemaElement.getModel(), schemaElementMatch); + for (FuzzyResult match : matches) { + SchemaElement schemaElement = match.getSchemaElement(); + Set regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement); + if (regElementSet.contains(schemaElement.getId())) { + continue; } + SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() + .element(schemaElement) + .word(schemaElement.getName()) + .detectWord(match.getDetectWord()) + .frequency(10000L) + .similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName())) + .build(); + log.info("add to schema, elementMatch {}", schemaElementMatch); + addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch); } } - private Set getRegElementSet(SchemaMapInfo schemaMap, SchemaElementType schemaElementType, - SchemaElement schemaElement) { + private Set getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) { List elements = schemaMap.getMatchedElements(schemaElement.getModel()); if (CollectionUtils.isEmpty(elements)) { return new HashSet<>(); } return elements.stream() - .filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType())) + .filter(elementMatch -> + SchemaElementType.METRIC.equals(elementMatch.getElement().getType()) + || SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType())) .map(elementMatch -> elementMatch.getElement().getId()) .collect(Collectors.toSet()); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java index 554518a84..173227635 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpDictMapper.java @@ -13,7 +13,6 @@ import com.tencent.supersonic.knowledge.utils.NatureHelper; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @@ -33,11 +32,7 @@ public class HanlpDictMapper extends BaseMapper { HanlpMatchStrategy matchStrategy = ContextUtils.getBean(HanlpMatchStrategy.class); - Set detectModelIds = getModelIds(queryContext); - - terms = filterByModelIds(terms, detectModelIds); - - List matches = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds); + List matches = matchStrategy.getMatches(queryContext, terms); HanlpHelper.transLetterOriginal(matches); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpMatchStrategy.java index 96657e4e7..0d8136901 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/HanlpMatchStrategy.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.mapper; import com.hankcs.hanlp.seg.common.Term; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.common.pojo.Constants; @@ -33,7 +34,9 @@ public class HanlpMatchStrategy extends BaseMatchStrategy { private OptimizationConfig optimizationConfig; @Override - public Map> match(QueryReq queryReq, List terms, Set detectModelIds) { + public Map> match(QueryContext queryContext, List terms, + Set detectModelIds) { + QueryReq queryReq = queryContext.getRequest(); String text = queryReq.getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; @@ -41,7 +44,7 @@ public class HanlpMatchStrategy extends BaseMatchStrategy { log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds); - List detects = detect(queryReq, terms, detectModelIds); + List detects = detect(queryContext, terms, detectModelIds); Map> result = new HashMap<>(); result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); @@ -54,8 +57,9 @@ public class HanlpMatchStrategy extends BaseMatchStrategy { && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length(); } - public void detectByStep(QueryReq queryReq, Set existResults, Set detectModelIds, + public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, Integer startIndex, Integer index, int offset) { + QueryReq queryReq = queryContext.getRequest(); String text = queryReq.getQueryText(); Integer agentId = queryReq.getAgentId(); String detectSegment = text.substring(startIndex, index); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java index db25d5f44..924f0f984 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MatchStrategy.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.chat.mapper; import com.hankcs.hanlp.seg.common.Term; -import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import java.util.List; import java.util.Map; import java.util.Set; @@ -11,6 +11,6 @@ import java.util.Set; */ public interface MatchStrategy { - Map> match(QueryReq queryReq, List terms, Set detectModelId); + Map> match(QueryContext queryContext, List terms, Set detectModelId); } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java index 96192c723..88b679463 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/SearchMatchStrategy.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.mapper; import com.google.common.collect.Lists; import com.hankcs.hanlp.seg.common.Term; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult; @@ -25,8 +26,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy { private static final int SEARCH_SIZE = 3; @Override - public Map> match(QueryReq queryReq, List originals, + public Map> match(QueryContext queryContext, List originals, Set detectModelIds) { + QueryReq queryReq = queryContext.getRequest(); String text = queryReq.getQueryText(); Map regOffsetToLength = getRegOffsetToLength(originals); @@ -87,7 +89,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy { } @Override - public void detectByStep(QueryReq queryReq, Set results, Set detectModelIds, + public void detectByStep(QueryContext queryContext, Set results, Set detectModelIds, Integer startIndex, Integer i, int offset) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java index c3d79f78b..343fc2222 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java @@ -4,6 +4,7 @@ import com.github.benmanes.caffeine.cache.Cache; import com.google.common.collect.Lists; import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; @@ -94,8 +95,10 @@ public class SearchServiceImpl implements SearchService { MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); Set detectModelIds = mapperHelper.getModelIds(queryReq); + QueryContext queryContext = new QueryContext(); + queryContext.setRequest(queryReq); Map> regTextMap = - searchMatchStrategy.match(queryReq, originals, detectModelIds); + searchMatchStrategy.match(queryContext, originals, detectModelIds); regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue())); // 4.get the most matching data diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/FuzzyResult.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/FuzzyResult.java new file mode 100644 index 000000000..ee3e0e67c --- /dev/null +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/FuzzyResult.java @@ -0,0 +1,30 @@ +package com.tencent.supersonic.knowledge.dictionary; + +import com.google.common.base.Objects; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import lombok.Data; +import lombok.ToString; + +@Data +@ToString +public class FuzzyResult extends MapResult { + + private SchemaElement schemaElement; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FuzzyResult that = (FuzzyResult) o; + return Objects.equal(name, that.name) && Objects.equal(schemaElement, that.schemaElement); + } + + @Override + public int hashCode() { + return Objects.hashCode(name, schemaElement); + } +} \ No newline at end of file