mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(Headless) Add STRICT, MODERATE, and LOOSE modes in the mapper phase. (#900)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
@@ -9,6 +10,7 @@ import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -29,7 +31,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -69,8 +71,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
return;
|
||||
Set<String> detectSegments) {
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||
@@ -152,6 +153,11 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset);
|
||||
String detectSegment, int offset);
|
||||
|
||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||
double decreaseAmount = (threshold - minThreshold) / 4;
|
||||
double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount;
|
||||
return divideThreshold >= minThreshold ? divideThreshold : minThreshold;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
String detectSegment, int offset) {
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
@@ -94,22 +94,20 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
if (!existElement) {
|
||||
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||
|
||||
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||
: metricDimensionMinThresholdConfig;
|
||||
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||
modelElementMatches, metricDimensionThresholdConfig);
|
||||
threshold = threshold / 2;
|
||||
log.info("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
||||
modelElementMatches, threshold);
|
||||
}
|
||||
return metricDimensionThresholdConfig;
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
|
||||
@@ -68,16 +68,18 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
optimizationConfig.getEmbeddingMapperBatch());
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext.getModelIdToDataSetIds());
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
|
||||
}
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperThreshold();
|
||||
// step1. build query params
|
||||
List<String> queryTextsSub, QueryContext queryContext) {
|
||||
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
double threshold = getThreshold(optimizationConfig.getEmbeddingMapperThreshold(),
|
||||
optimizationConfig.getEmbeddingMapperMinThreshold(), queryContext.getMapModeEnum());
|
||||
// step1. build query params
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
|
||||
@@ -94,7 +96,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
||||
return retrieval.getDistance() > 1 - distance.doubleValue();
|
||||
return retrieval.getDistance() > 1 - threshold;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
@@ -39,7 +40,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -61,7 +62,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
String detectSegment, int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
@@ -84,7 +85,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
// step4. filter by similarity
|
||||
hanlpMapResults = hanlpMapResults.stream()
|
||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
||||
>= getThresholdMatch(term.getNatures(), queryContext))
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
@@ -118,4 +119,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
public String getMapKey(HanlpMapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||
}
|
||||
|
||||
public double getThresholdMatch(List<String> natures, QueryContext queryContext) {
|
||||
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
if (mapperHelper.existDimensionValues(natures)) {
|
||||
threshold = optimizationConfig.getDimensionValueThresholdConfig();
|
||||
minThreshold = optimizationConfig.getDimensionValueMinThresholdConfig();
|
||||
}
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,13 +45,6 @@ public class MapperHelper {
|
||||
return index;
|
||||
}
|
||||
|
||||
public double getThresholdMatch(List<String> natures) {
|
||||
if (existDimensionValues(natures)) {
|
||||
return optimizationConfig.getDimensionValueThresholdConfig();
|
||||
}
|
||||
return optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
}
|
||||
|
||||
/***
|
||||
* exist dimension values
|
||||
* @param natures
|
||||
|
||||
@@ -26,6 +26,9 @@ public class OptimizationConfig {
|
||||
@Value("${metric.dimension.threshold:0.3}")
|
||||
private Double metricDimensionThresholdConfig;
|
||||
|
||||
@Value("${dimension.value.min.threshold:0.2}")
|
||||
private Double dimensionValueMinThresholdConfig;
|
||||
|
||||
@Value("${dimension.value.threshold:0.5}")
|
||||
private Double dimensionValueThresholdConfig;
|
||||
|
||||
@@ -52,6 +55,9 @@ public class OptimizationConfig {
|
||||
@Value("${embedding.mapper.round.number:10}")
|
||||
private int embeddingMapperRoundNumber;
|
||||
|
||||
@Value("${embedding.mapper.min.threshold:0.6}")
|
||||
private Double embeddingMapperMinThreshold;
|
||||
|
||||
@Value("${embedding.mapper.threshold:0.99}")
|
||||
private Double embeddingMapperThreshold;
|
||||
|
||||
@@ -95,6 +101,10 @@ public class OptimizationConfig {
|
||||
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueMinThresholdConfig() {
|
||||
return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueThresholdConfig() {
|
||||
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||
}
|
||||
@@ -135,6 +145,10 @@ public class OptimizationConfig {
|
||||
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperMinThreshold() {
|
||||
return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperThreshold() {
|
||||
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
@@ -37,6 +38,7 @@ public class QueryContext {
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user