mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless) Adjust the similarity threshold for vector retrieval (#887)
This commit is contained in:
@@ -74,8 +74,8 @@ public class SysParameter {
|
||||
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("embedding.mapper.number", "5",
|
||||
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("embedding.mapper.distance.threshold",
|
||||
"0.01", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("embedding.mapper.threshold",
|
||||
"0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
|
||||
|
||||
//parser config
|
||||
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
|
||||
|
||||
@@ -49,13 +49,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
Set<String> detectSegments) {
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
@@ -73,9 +73,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperThreshold();
|
||||
// step1. build query params
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
@@ -94,7 +94,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
||||
return retrieval.getDistance() > distance.doubleValue();
|
||||
return retrieval.getDistance() > 1 - distance.doubleValue();
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
@@ -66,7 +66,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
selectFields.addAll(whereFields);
|
||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
|
||||
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
@@ -52,8 +52,8 @@ public class OptimizationConfig {
|
||||
@Value("${embedding.mapper.round.number:10}")
|
||||
private int embeddingMapperRoundNumber;
|
||||
|
||||
@Value("${embedding.mapper.distance.threshold:0.01}")
|
||||
private Double embeddingMapperDistanceThreshold;
|
||||
@Value("${embedding.mapper.threshold:0.99}")
|
||||
private Double embeddingMapperThreshold;
|
||||
|
||||
@Value("${s2SQL.linking.value.switch:true}")
|
||||
private boolean useLinkingValueSwitch;
|
||||
@@ -135,8 +135,8 @@ public class OptimizationConfig {
|
||||
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperDistanceThreshold() {
|
||||
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
|
||||
public Double getEmbeddingMapperThreshold() {
|
||||
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
|
||||
}
|
||||
|
||||
public boolean isUseLinkingValueSwitch() {
|
||||
|
||||
Reference in New Issue
Block a user