mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless) Adjust the similarity threshold for vector retrieval (#887)
This commit is contained in:
@@ -74,8 +74,8 @@ public class SysParameter {
|
|||||||
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
|
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
|
||||||
parameters.add(new Parameter("embedding.mapper.number", "5",
|
parameters.add(new Parameter("embedding.mapper.number", "5",
|
||||||
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
|
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
|
||||||
parameters.add(new Parameter("embedding.mapper.distance.threshold",
|
parameters.add(new Parameter("embedding.mapper.threshold",
|
||||||
"0.01", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
|
"0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
|
||||||
|
|
||||||
//parser config
|
//parser config
|
||||||
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
|
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
|
||||||
|
|||||||
@@ -49,13 +49,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
|
||||||
String detectSegment, int offset) {
|
String detectSegment, int offset) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
|
|
||||||
List<String> queryTextsList = detectSegments.stream()
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
.map(detectSegment -> detectSegment.trim())
|
.map(detectSegment -> detectSegment.trim())
|
||||||
@@ -73,9 +73,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||||
List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) {
|
List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
Double distance = optimizationConfig.getEmbeddingMapperThreshold();
|
||||||
// step1. build query params
|
// step1. build query params
|
||||||
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||||
@@ -94,7 +94,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||||
retrievals.removeIf(retrieval -> {
|
retrievals.removeIf(retrieval -> {
|
||||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
||||||
return retrieval.getDistance() > distance.doubleValue();
|
return retrieval.getDistance() > 1 - distance.doubleValue();
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
selectFields.addAll(whereFields);
|
selectFields.addAll(whereFields);
|
||||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(whereFields);
|
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
|
||||||
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
||||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
|
|||||||
@@ -52,8 +52,8 @@ public class OptimizationConfig {
|
|||||||
@Value("${embedding.mapper.round.number:10}")
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
private int embeddingMapperRoundNumber;
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
@Value("${embedding.mapper.distance.threshold:0.01}")
|
@Value("${embedding.mapper.threshold:0.99}")
|
||||||
private Double embeddingMapperDistanceThreshold;
|
private Double embeddingMapperThreshold;
|
||||||
|
|
||||||
@Value("${s2SQL.linking.value.switch:true}")
|
@Value("${s2SQL.linking.value.switch:true}")
|
||||||
private boolean useLinkingValueSwitch;
|
private boolean useLinkingValueSwitch;
|
||||||
@@ -135,8 +135,8 @@ public class OptimizationConfig {
|
|||||||
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Double getEmbeddingMapperDistanceThreshold() {
|
public Double getEmbeddingMapperThreshold() {
|
||||||
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
|
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isUseLinkingValueSwitch() {
|
public boolean isUseLinkingValueSwitch() {
|
||||||
|
|||||||
Reference in New Issue
Block a user