mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Perform a complete refactoring of the mapper and integrate the fuzzyNameMapper into the BaseMatchStrategy (#328)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
@@ -10,15 +9,12 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
@@ -57,35 +53,6 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
schemaElementMatches.add(schemaElementMatch);
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(QueryContext queryContext) {
|
||||
return ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest());
|
||||
}
|
||||
|
||||
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return detectModelIds.contains(modelId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
return terms;
|
||||
}
|
||||
|
||||
public void logTerms(List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return;
|
||||
}
|
||||
for (Term term : terms) {
|
||||
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -28,26 +29,25 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||
String text = queryReq.getQueryText();
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
|
||||
List<T> detects = detect(queryReq, terms, detectModelIds);
|
||||
List<T> detects = detect(queryContext, terms, detectModelIds);
|
||||
Map<MatchText, List<T>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<T> detect(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = queryReq.getQueryText();
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
|
||||
for (Integer index = 0; index <= text.length() - 1; ) {
|
||||
@@ -56,7 +56,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
int offset = mapperHelper.getStepOffset(terms, index);
|
||||
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
||||
if (i <= text.length()) {
|
||||
detectByStep(queryReq, results, detectModelIds, index, i, offset);
|
||||
detectByStep(queryContext, results, detectModelIds, index, i, offset);
|
||||
}
|
||||
}
|
||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||
@@ -94,8 +94,10 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||
Map<MatchText, List<T>> matchResult = match(queryReq, terms, detectModelIds);
|
||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
@@ -110,12 +112,37 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return detectModelIds.contains(modelId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
return terms;
|
||||
}
|
||||
|
||||
public void logTerms(List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return;
|
||||
}
|
||||
for (Term term : terms) {
|
||||
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||
}
|
||||
}
|
||||
|
||||
public abstract boolean needDelete(T oneRoundResult, T existResult);
|
||||
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryReq queryReq, Set<T> results, Set<Long> detectModelIds, Integer startIndex,
|
||||
Integer index, int offset);
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@@ -29,12 +28,7 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
|
||||
Set<Long> detectModelIds = getModelIds(queryContext);
|
||||
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
|
||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
|
||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
HanlpHelper.transLetterOriginal(matchResults);
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
@@ -46,8 +47,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryReq queryReq, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String detectSegment = queryReq.getQueryText().substring(startIndex, index);
|
||||
// step1. build query params
|
||||
if (StringUtils.isBlank(detectSegment)
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Fuzzy Match Strategy
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class FuzzyMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<FuzzyResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
this.allElements = getSchemaElements();
|
||||
return super.match(queryContext, terms, detectModelIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(FuzzyResult oneRoundResult, FuzzyResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(FuzzyResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<FuzzyResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
|
||||
// step1. build query params
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
String name = entry.getKey();
|
||||
if (!name.contains(detectSegment)
|
||||
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
FuzzyResult fuzzyResult = new FuzzyResult();
|
||||
fuzzyResult.setDetectWord(detectSegment);
|
||||
fuzzyResult.setName(schemaElement.getName());
|
||||
fuzzyResult.setSchemaElement(schemaElement);
|
||||
existResults.add(fuzzyResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<SchemaElement> getSchemaElements() {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
|
||||
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
|
||||
return allElements;
|
||||
}
|
||||
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
if (!existElement) {
|
||||
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||
|
||||
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||
: metricDimensionMinThresholdConfig;
|
||||
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||
modelElementMatches, metricDimensionThresholdConfig);
|
||||
}
|
||||
return metricDimensionThresholdConfig;
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
return models.stream().collect(
|
||||
Collectors.toMap(SchemaElement::getName, a -> {
|
||||
Set<SchemaElement> result = new HashSet<>();
|
||||
result.add(a);
|
||||
return result;
|
||||
}, (k1, k2) -> {
|
||||
k1.addAll(k2);
|
||||
return k1;
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -6,18 +6,11 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -34,145 +27,39 @@ public class FuzzyNameMapper extends BaseMapper {
|
||||
|
||||
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
detectAndAddToSchema(queryContext, terms, semanticSchema.getDimensions(), SchemaElementType.DIMENSION);
|
||||
|
||||
detectAndAddToSchema(queryContext, terms, semanticSchema.getMetrics(), SchemaElementType.METRIC);
|
||||
|
||||
}
|
||||
|
||||
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> models,
|
||||
SchemaElementType schemaElementType) {
|
||||
try {
|
||||
|
||||
Map<String, Set<SchemaElement>> modelResultSet = getResultSet(queryContext, terms, models);
|
||||
|
||||
addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("detectAndAddToSchema error", e);
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
|
||||
List<SchemaElement> models) {
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
FuzzyMatchStrategy fuzzyMatchStrategy = ContextUtils.getBean(FuzzyMatchStrategy.class);
|
||||
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
List<FuzzyResult> matches = fuzzyMatchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(models);
|
||||
|
||||
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
|
||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||
|
||||
Map<String, Set<SchemaElement>> modelResultSet = new HashMap<>();
|
||||
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
|
||||
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
|
||||
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
|
||||
if (endIndex > queryText.length()) {
|
||||
continue;
|
||||
}
|
||||
String detectSegment = queryText.substring(startIndex, endIndex);
|
||||
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
String name = entry.getKey();
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!name.contains(detectSegment)
|
||||
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||
continue;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
Set<SchemaElement> preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements);
|
||||
if (Objects.nonNull(preSchemaElements)) {
|
||||
preSchemaElements.addAll(schemaElements);
|
||||
}
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
return modelResultSet;
|
||||
}
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
|
||||
.getModelElementMatches();
|
||||
boolean existElement = modelElementMatches.entrySet().stream()
|
||||
.anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
if (!existElement) {
|
||||
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||
|
||||
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||
: metricDimensionMinThresholdConfig;
|
||||
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||
modelElementMatches, metricDimensionThresholdConfig);
|
||||
}
|
||||
return metricDimensionThresholdConfig;
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
return models.stream().collect(
|
||||
Collectors.toMap(SchemaElement::getName, a -> {
|
||||
Set<SchemaElement> result = new HashSet<>();
|
||||
result.add(a);
|
||||
return result;
|
||||
}, (k1, k2) -> {
|
||||
k1.addAll(k2);
|
||||
return k1;
|
||||
}));
|
||||
}
|
||||
|
||||
private void addToSchemaMapInfo(Map<String, Set<SchemaElement>> mapResultRowSet, SchemaMapInfo schemaMap,
|
||||
SchemaElementType schemaElementType) {
|
||||
if (Objects.isNull(mapResultRowSet) || mapResultRowSet.size() <= 0) {
|
||||
return;
|
||||
}
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
|
||||
for (Map.Entry<String, Set<SchemaElement>> entry : mapResultRowSet.entrySet()) {
|
||||
String detectWord = entry.getKey();
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
|
||||
Set<Long> regElementSet = getRegElementSet(schemaMap, schemaElementType, schemaElement);
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(detectWord)
|
||||
.frequency(10000L)
|
||||
.similarity(mapperHelper.getSimilarity(detectWord, schemaElement.getName()))
|
||||
.build();
|
||||
log.info("schemaElementType:{},add to schema, elementMatch {}", schemaElementType, schemaElementMatch);
|
||||
addToSchemaMap(schemaMap, schemaElement.getModel(), schemaElementMatch);
|
||||
for (FuzzyResult match : matches) {
|
||||
SchemaElement schemaElement = match.getSchemaElement();
|
||||
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(10000L)
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElementType schemaElementType,
|
||||
SchemaElement schemaElement) {
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return elements.stream()
|
||||
.filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType()))
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.getElement().getId())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -33,11 +32,7 @@ public class HanlpDictMapper extends BaseMapper {
|
||||
|
||||
HanlpMatchStrategy matchStrategy = ContextUtils.getBean(HanlpMatchStrategy.class);
|
||||
|
||||
Set<Long> detectModelIds = getModelIds(queryContext);
|
||||
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
|
||||
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
|
||||
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
HanlpHelper.transLetterOriginal(matches);
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
@@ -33,7 +34,9 @@ public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -41,7 +44,7 @@ public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
|
||||
List<HanlpMapResult> detects = detect(queryReq, terms, detectModelIds);
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
|
||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
@@ -54,8 +57,9 @@ public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
Integer agentId = queryReq.getAgentId();
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@@ -11,6 +11,6 @@ import java.util.Set;
|
||||
*/
|
||||
public interface MatchStrategy<T> {
|
||||
|
||||
Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);
|
||||
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
@@ -25,8 +26,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
private static final int SEARCH_SIZE = 3;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> originals,
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
|
||||
Set<Long> detectModelIds) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
@@ -87,7 +89,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
||||
Integer startIndex,
|
||||
Integer i, int offset) {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.github.benmanes.caffeine.cache.Cache;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
@@ -94,8 +95,10 @@ public class SearchServiceImpl implements SearchService {
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq);
|
||||
|
||||
QueryContext queryContext = new QueryContext();
|
||||
queryContext.setRequest(queryReq);
|
||||
Map<MatchText, List<HanlpMapResult>> regTextMap =
|
||||
searchMatchStrategy.match(queryReq, originals, detectModelIds);
|
||||
searchMatchStrategy.match(queryContext, originals, detectModelIds);
|
||||
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
||||
|
||||
// 4.get the most matching data
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.knowledge.dictionary;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class FuzzyResult extends MapResult {
|
||||
|
||||
private SchemaElement schemaElement;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
FuzzyResult that = (FuzzyResult) o;
|
||||
return Objects.equal(name, that.name) && Objects.equal(schemaElement, that.schemaElement);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(name, schemaElement);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user