mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(Headless) Add STRICT, MODERATE, and LOOSE modes in the mapper phase. (#900)
This commit is contained in:
@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -57,12 +58,14 @@ public class SysParameter {
|
||||
parameters.add(new Parameter("metric.dimension.threshold", "0.3",
|
||||
"指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("metric.dimension.min.threshold", "0.3",
|
||||
"指标名、维度名最小文本相似度阈值",
|
||||
"最小编辑距离阈值, 在FuzzyNameMapper中, 如果上面设定的编辑距离阈值的1/2大于该最小编辑距离, 则取上面设定阈值的1/2作为阈值, 否则取该阈值",
|
||||
parameters.add(new Parameter("metric.dimension.min.threshold", "0.25",
|
||||
"指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("dimension.value.threshold", "0.5",
|
||||
"维度值最小文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("dimension.value.min.threshold", "0.3",
|
||||
"维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置"));
|
||||
|
||||
//embedding mapper config
|
||||
@@ -76,6 +79,8 @@ public class SysParameter {
|
||||
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("embedding.mapper.threshold",
|
||||
"0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("embedding.mapper.min.threshold",
|
||||
"0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"));
|
||||
|
||||
//parser config
|
||||
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||
|
||||
public enum MapModeEnum {
|
||||
STRICT,
|
||||
MODERATE,
|
||||
LOOSE;
|
||||
STRICT(0),
|
||||
MODERATE(2),
|
||||
LOOSE(4);
|
||||
public int threshold;
|
||||
|
||||
MapModeEnum(Integer threshold) {
|
||||
this.threshold = threshold;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,5 +14,5 @@ public class QueryMapReq {
|
||||
private List<String> dataSetNames;
|
||||
private User user;
|
||||
private Integer topN = 10;
|
||||
private MapModeEnum mapModeEnum;
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Set;
|
||||
@@ -17,5 +18,6 @@ public class QueryReq {
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
@@ -9,6 +10,7 @@ import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -29,7 +31,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -69,8 +71,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
return;
|
||||
Set<String> detectSegments) {
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||
@@ -152,6 +153,11 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset);
|
||||
String detectSegment, int offset);
|
||||
|
||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||
double decreaseAmount = (threshold - minThreshold) / 4;
|
||||
double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount;
|
||||
return divideThreshold >= minThreshold ? divideThreshold : minThreshold;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
String detectSegment, int offset) {
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
@@ -94,22 +94,20 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
if (!existElement) {
|
||||
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||
|
||||
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||
: metricDimensionMinThresholdConfig;
|
||||
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||
modelElementMatches, metricDimensionThresholdConfig);
|
||||
threshold = threshold / 2;
|
||||
log.info("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
||||
modelElementMatches, threshold);
|
||||
}
|
||||
return metricDimensionThresholdConfig;
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
|
||||
@@ -68,16 +68,18 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
optimizationConfig.getEmbeddingMapperBatch());
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext.getModelIdToDataSetIds());
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
|
||||
}
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperThreshold();
|
||||
// step1. build query params
|
||||
List<String> queryTextsSub, QueryContext queryContext) {
|
||||
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
double threshold = getThreshold(optimizationConfig.getEmbeddingMapperThreshold(),
|
||||
optimizationConfig.getEmbeddingMapperMinThreshold(), queryContext.getMapModeEnum());
|
||||
// step1. build query params
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
|
||||
@@ -94,7 +96,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
||||
return retrieval.getDistance() > 1 - distance.doubleValue();
|
||||
return retrieval.getDistance() > 1 - threshold;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
@@ -39,7 +40,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -61,7 +62,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
String detectSegment, int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
@@ -84,7 +85,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
// step4. filter by similarity
|
||||
hanlpMapResults = hanlpMapResults.stream()
|
||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
||||
>= getThresholdMatch(term.getNatures(), queryContext))
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
@@ -118,4 +119,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
public String getMapKey(HanlpMapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||
}
|
||||
|
||||
public double getThresholdMatch(List<String> natures, QueryContext queryContext) {
|
||||
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
if (mapperHelper.existDimensionValues(natures)) {
|
||||
threshold = optimizationConfig.getDimensionValueThresholdConfig();
|
||||
minThreshold = optimizationConfig.getDimensionValueMinThresholdConfig();
|
||||
}
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,13 +45,6 @@ public class MapperHelper {
|
||||
return index;
|
||||
}
|
||||
|
||||
public double getThresholdMatch(List<String> natures) {
|
||||
if (existDimensionValues(natures)) {
|
||||
return optimizationConfig.getDimensionValueThresholdConfig();
|
||||
}
|
||||
return optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
}
|
||||
|
||||
/***
|
||||
* exist dimension values
|
||||
* @param natures
|
||||
|
||||
@@ -26,6 +26,9 @@ public class OptimizationConfig {
|
||||
@Value("${metric.dimension.threshold:0.3}")
|
||||
private Double metricDimensionThresholdConfig;
|
||||
|
||||
@Value("${dimension.value.min.threshold:0.2}")
|
||||
private Double dimensionValueMinThresholdConfig;
|
||||
|
||||
@Value("${dimension.value.threshold:0.5}")
|
||||
private Double dimensionValueThresholdConfig;
|
||||
|
||||
@@ -52,6 +55,9 @@ public class OptimizationConfig {
|
||||
@Value("${embedding.mapper.round.number:10}")
|
||||
private int embeddingMapperRoundNumber;
|
||||
|
||||
@Value("${embedding.mapper.min.threshold:0.6}")
|
||||
private Double embeddingMapperMinThreshold;
|
||||
|
||||
@Value("${embedding.mapper.threshold:0.99}")
|
||||
private Double embeddingMapperThreshold;
|
||||
|
||||
@@ -95,6 +101,10 @@ public class OptimizationConfig {
|
||||
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueMinThresholdConfig() {
|
||||
return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueThresholdConfig() {
|
||||
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||
}
|
||||
@@ -135,6 +145,10 @@ public class OptimizationConfig {
|
||||
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperMinThreshold() {
|
||||
return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperThreshold() {
|
||||
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
@@ -37,6 +38,7 @@ public class QueryContext {
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/semantic/query")
|
||||
@RequestMapping("/api/semantic/meta")
|
||||
@Slf4j
|
||||
public class MetaDiscoveryApiController {
|
||||
|
||||
|
||||
@@ -179,6 +179,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
.mapInfo(new SchemaMapInfo())
|
||||
.modelIdToDataSetIds(modelIdToDataSetIds)
|
||||
.text2SQLType(queryReq.getText2SQLType())
|
||||
.mapModeEnum(queryReq.getMapModeEnum())
|
||||
.build();
|
||||
BeanUtils.copyProperties(queryReq, queryCtx);
|
||||
return queryCtx;
|
||||
|
||||
Reference in New Issue
Block a user