(improvement)(chat) Perform a complete refactoring of the mapper and integrate FuzzyNameMapper into BaseMatchStrategy (#328)

Author: lexluo09
Date: 2023-11-06 17:43:46 +08:00
Committed by: GitHub
Parent: f5f9c0314a
Commit: e00b935c1f
12 changed files with 244 additions and 201 deletions
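
Summary of the change: model-id resolution and term filtering move out of BaseMapper into BaseMatchStrategy, the strategy signatures switch from QueryReq to QueryContext, and the fuzzy name detection previously inlined in FuzzyNameMapper becomes a reusable FuzzyMatchStrategy with a new FuzzyResult type. A minimal sketch of the new entry point, assembled from the hunks below (the final aggregation line is a simplification, not the verbatim code):

// BaseMatchStrategy sketch: the strategy now resolves model ids and filters terms itself
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
    Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
    terms = filterByModelIds(terms, detectModelIds);
    Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
    List<T> matches = new ArrayList<>();
    if (Objects.isNull(matchResult)) {
        return matches;
    }
    // simplified aggregation (assumption); the actual method applies further handling
    matchResult.values().forEach(matches::addAll);
    return matches;
}

// Callers such as EmbeddingMapper and HanlpDictMapper shrink to a single call:
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);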

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -10,15 +9,12 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
@@ -57,35 +53,6 @@ public abstract class BaseMapper implements SchemaMapper {
schemaElementMatches.add(schemaElementMatch);
}
public Set<Long> getModelIds(QueryContext queryContext) {
return ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest());
}
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
logTerms(terms);
if (CollectionUtils.isNotEmpty(detectModelIds)) {
terms = terms.stream().filter(term -> {
Long modelId = NatureHelper.getModelId(term.getNature().toString());
if (Objects.nonNull(modelId)) {
return detectModelIds.contains(modelId);
}
return false;
}).collect(Collectors.toList());
log.info("terms filter by modelIds:{}", detectModelIds);
logTerms(terms);
}
return terms;
}
public void logTerms(List<Term> terms) {
if (CollectionUtils.isEmpty(terms)) {
return;
}
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
}
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
SchemaElement element = new SchemaElement();
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -28,26 +29,25 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Autowired
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
String text = queryContext.getRequest().getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
List<T> detects = detect(queryReq, terms, detectModelIds);
List<T> detects = detect(queryContext, terms, detectModelIds);
Map<MatchText, List<T>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
public List<T> detect(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
String text = queryReq.getQueryText();
String text = queryContext.getRequest().getQueryText();
Set<T> results = new HashSet<>();
for (Integer index = 0; index <= text.length() - 1; ) {
@@ -56,7 +56,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
int offset = mapperHelper.getStepOffset(terms, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
detectByStep(queryReq, results, detectModelIds, index, i, offset);
detectByStep(queryContext, results, detectModelIds, index, i, offset);
}
}
index = mapperHelper.getStepIndex(regOffsetToLength, index);
@@ -94,8 +94,10 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
}
}
public List<T> getMatches(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
Map<MatchText, List<T>> matchResult = match(queryReq, terms, detectModelIds);
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
terms = filterByModelIds(terms, detectModelIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
List<T> matches = new ArrayList<>();
if (Objects.isNull(matchResult)) {
return matches;
@@ -110,12 +112,37 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
return matches;
}
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
logTerms(terms);
if (CollectionUtils.isNotEmpty(detectModelIds)) {
terms = terms.stream().filter(term -> {
Long modelId = NatureHelper.getModelId(term.getNature().toString());
if (Objects.nonNull(modelId)) {
return detectModelIds.contains(modelId);
}
return false;
}).collect(Collectors.toList());
log.info("terms filter by modelIds:{}", detectModelIds);
logTerms(terms);
}
return terms;
}
public void logTerms(List<Term> terms) {
if (CollectionUtils.isEmpty(terms)) {
return;
}
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
}
public abstract boolean needDelete(T oneRoundResult, T existResult);
public abstract String getMapKey(T a);
public abstract void detectByStep(QueryReq queryReq, Set<T> results, Set<Long> detectModelIds, Integer startIndex,
Integer index, int offset);
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
}

View File

@@ -11,7 +11,6 @@ import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.List;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -29,12 +28,7 @@ public class EmbeddingMapper extends BaseMapper {
List<Term> terms = HanlpHelper.getTerms(queryText);
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
Set<Long> detectModelIds = getModelIds(queryContext);
terms = filterByModelIds(terms, detectModelIds);
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
HanlpHelper.transLetterOriginal(matchResults);

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
@@ -46,8 +47,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
return a.getName() + Constants.UNDERLINE + a.getId();
}
public void detectByStep(QueryReq queryReq, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) {
QueryReq queryReq = queryContext.getRequest();
String detectSegment = queryReq.getQueryText().substring(startIndex, index);
// step1. build query params
if (StringUtils.isBlank(detectSegment)

View File

@@ -0,0 +1,132 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
/**
* Fuzzy Match Strategy
*/
@Service
@Slf4j
public class FuzzyMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MapperHelper mapperHelper;
@Autowired
private SchemaService schemaService;
private List<SchemaElement> allElements;
@Override
public Map<MatchText, List<FuzzyResult>> match(QueryContext queryContext, List<Term> terms,
Set<Long> detectModelIds) {
this.allElements = getSchemaElements();
return super.match(queryContext, terms, detectModelIds);
}
@Override
public boolean needDelete(FuzzyResult oneRoundResult, FuzzyResult existResult) {
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
}
@Override
public String getMapKey(FuzzyResult a) {
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
+ Constants.UNDERLINE + a.getSchemaElement().getName();
}
public void detectByStep(QueryContext queryContext, Set<FuzzyResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) {
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
// step1. build query params
if (StringUtils.isBlank(detectSegment)) {
return;
}
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
Double metricDimensionThresholdConfig = getThreshold(queryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
String name = entry.getKey();
if (!name.contains(detectSegment)
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(modelIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {
FuzzyResult fuzzyResult = new FuzzyResult();
fuzzyResult.setDetectWord(detectSegment);
fuzzyResult.setName(schemaElement.getName());
fuzzyResult.setSchemaElement(schemaElement);
existResults.add(fuzzyResult);
}
}
}
private List<SchemaElement> getSchemaElements() {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
return allElements;
}
private Double getThreshold(QueryContext queryContext) {
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
double halfThreshold = metricDimensionThresholdConfig / 2;
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig;
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
modelElementMatches, metricDimensionThresholdConfig);
}
return metricDimensionThresholdConfig;
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
return models.stream().collect(
Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>();
result.add(a);
return result;
}, (k1, k2) -> {
k1.addAll(k2);
return k1;
}));
}
}
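
A hedged usage sketch of the new strategy, mirroring the FuzzyNameMapper diff that follows: the mapper obtains the strategy from the Spring context and delegates the scan of the query text to it.

// Assumed caller-side usage; the concrete wiring is shown in the next file
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
FuzzyMatchStrategy fuzzyMatchStrategy = ContextUtils.getBean(FuzzyMatchStrategy.class);
List<FuzzyResult> matches = fuzzyMatchStrategy.getMatches(queryContext, terms);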

View File

@@ -6,18 +6,11 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
@@ -34,145 +27,39 @@ public class FuzzyNameMapper extends BaseMapper {
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
detectAndAddToSchema(queryContext, terms, semanticSchema.getDimensions(), SchemaElementType.DIMENSION);
detectAndAddToSchema(queryContext, terms, semanticSchema.getMetrics(), SchemaElementType.METRIC);
}
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> models,
SchemaElementType schemaElementType) {
try {
Map<String, Set<SchemaElement>> modelResultSet = getResultSet(queryContext, terms, models);
addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType);
} catch (Exception e) {
log.error("detectAndAddToSchema error", e);
}
}
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
List<SchemaElement> models) {
String queryText = queryContext.getRequest().getQueryText();
FuzzyMatchStrategy fuzzyMatchStrategy = ContextUtils.getBean(FuzzyMatchStrategy.class);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
Double metricDimensionThresholdConfig = getThreshold(queryContext);
List<FuzzyResult> matches = fuzzyMatchStrategy.getMatches(queryContext, terms);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(models);
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
Map<String, Set<SchemaElement>> modelResultSet = new HashMap<>();
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
if (endIndex > queryText.length()) {
continue;
}
String detectSegment = queryText.substring(startIndex, endIndex);
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
String name = entry.getKey();
Set<SchemaElement> schemaElements = entry.getValue();
if (!name.contains(detectSegment)
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
continue;
}
if (!CollectionUtils.isEmpty(modelIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
.collect(Collectors.toSet());
}
Set<SchemaElement> preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements);
if (Objects.nonNull(preSchemaElements)) {
preSchemaElements.addAll(schemaElements);
}
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
return modelResultSet;
}
private Double getThreshold(QueryContext queryContext) {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
.getModelElementMatches();
boolean existElement = modelElementMatches.entrySet().stream()
.anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
double halfThreshold = metricDimensionThresholdConfig / 2;
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig;
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
modelElementMatches, metricDimensionThresholdConfig);
}
return metricDimensionThresholdConfig;
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
return models.stream().collect(
Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>();
result.add(a);
return result;
}, (k1, k2) -> {
k1.addAll(k2);
return k1;
}));
}
private void addToSchemaMapInfo(Map<String, Set<SchemaElement>> mapResultRowSet, SchemaMapInfo schemaMap,
SchemaElementType schemaElementType) {
if (Objects.isNull(mapResultRowSet) || mapResultRowSet.size() <= 0) {
return;
}
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
for (Map.Entry<String, Set<SchemaElement>> entry : mapResultRowSet.entrySet()) {
String detectWord = entry.getKey();
Set<SchemaElement> schemaElements = entry.getValue();
for (SchemaElement schemaElement : schemaElements) {
Set<Long> regElementSet = getRegElementSet(schemaMap, schemaElementType, schemaElement);
if (regElementSet.contains(schemaElement.getId())) {
continue;
}
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement)
.word(schemaElement.getName())
.detectWord(detectWord)
.frequency(10000L)
.similarity(mapperHelper.getSimilarity(detectWord, schemaElement.getName()))
.build();
log.info("schemaElementType:{},add to schema, elementMatch {}", schemaElementType, schemaElementMatch);
addToSchemaMap(schemaMap, schemaElement.getModel(), schemaElementMatch);
for (FuzzyResult match : matches) {
SchemaElement schemaElement = match.getSchemaElement();
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
if (regElementSet.contains(schemaElement.getId())) {
continue;
}
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement)
.word(schemaElement.getName())
.detectWord(match.getDetectWord())
.frequency(10000L)
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
}
}
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElementType schemaElementType,
SchemaElement schemaElement) {
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
if (CollectionUtils.isEmpty(elements)) {
return new HashSet<>();
}
return elements.stream()
.filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType()))
.filter(elementMatch ->
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.getElement().getId())
.collect(Collectors.toSet());
}

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -33,11 +32,7 @@ public class HanlpDictMapper extends BaseMapper {
HanlpMatchStrategy matchStrategy = ContextUtils.getBean(HanlpMatchStrategy.class);
Set<Long> detectModelIds = getModelIds(queryContext);
terms = filterByModelIds(terms, detectModelIds);
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext, terms);
HanlpHelper.transLetterOriginal(matches);

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
@@ -33,7 +34,9 @@ public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
private OptimizationConfig optimizationConfig;
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
Set<Long> detectModelIds) {
QueryReq queryReq = queryContext.getRequest();
String text = queryReq.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
@@ -41,7 +44,7 @@ public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
List<HanlpMapResult> detects = detect(queryReq, terms, detectModelIds);
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
@@ -54,8 +57,9 @@ public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
}
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) {
QueryReq queryReq = queryContext.getRequest();
String text = queryReq.getQueryText();
Integer agentId = queryReq.getAgentId();
String detectSegment = text.substring(startIndex, index);

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -11,6 +11,6 @@ import java.util.Set;
*/
public interface MatchStrategy<T> {
Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
@@ -25,8 +26,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
private static final int SEARCH_SIZE = 3;
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> originals,
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
Set<Long> detectModelIds) {
QueryReq queryReq = queryContext.getRequest();
String text = queryReq.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
@@ -87,7 +89,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
@Override
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> results, Set<Long> detectModelIds,
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
Integer startIndex,
Integer i, int offset) {

View File

@@ -4,6 +4,7 @@ import com.github.benmanes.caffeine.cache.Cache;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
@@ -94,8 +95,10 @@ public class SearchServiceImpl implements SearchService {
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq);
QueryContext queryContext = new QueryContext();
queryContext.setRequest(queryReq);
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryReq, originals, detectModelIds);
searchMatchStrategy.match(queryContext, originals, detectModelIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
// 4.get the most matching data

View File

@@ -0,0 +1,30 @@
package com.tencent.supersonic.knowledge.dictionary;
import com.google.common.base.Objects;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class FuzzyResult extends MapResult {
private SchemaElement schemaElement;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FuzzyResult that = (FuzzyResult) o;
return Objects.equal(name, that.name) && Objects.equal(schemaElement, that.schemaElement);
}
@Override
public int hashCode() {
return Objects.hashCode(name, schemaElement);
}
}