(improvement)(headless) Adjust the similarity threshold for vector retrieval (#887)

This commit is contained in:
lexluo09
2024-04-06 10:03:52 +08:00
committed by GitHub
parent 3ef3c44277
commit 0577090b39
4 changed files with 12 additions and 12 deletions

View File

@@ -74,8 +74,8 @@ public class SysParameter {
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置")); "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.number", "5", parameters.add(new Parameter("embedding.mapper.number", "5",
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置")); "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.distance.threshold", parameters.add(new Parameter("embedding.mapper.threshold",
"0.01", "向量召回相似度阈值", "相似度于该阈值的则舍弃", "number", "Mapper相关配置")); "0.99", "向量召回相似度阈值", "相似度于该阈值的则舍弃", "number", "Mapper相关配置"));
//parser config //parser config
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT", Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",

View File

@@ -49,13 +49,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Override @Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds, public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) { String detectSegment, int offset) {
} }
@Override @Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds, protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) { Set<String> detectSegments) {
List<String> queryTextsList = detectSegments.stream() List<String> queryTextsList = detectSegments.stream()
.map(detectSegment -> detectSegment.trim()) .map(detectSegment -> detectSegment.trim())
@@ -73,9 +73,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
} }
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds, private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) { List<String> queryTextsSub, Map<Long, List<Long>> modelIdToDataSetIds) {
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber(); int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold(); Double distance = optimizationConfig.getEmbeddingMapperThreshold();
// step1. build query params // step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
@@ -94,7 +94,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
if (CollectionUtils.isNotEmpty(retrievals)) { if (CollectionUtils.isNotEmpty(retrievals)) {
retrievals.removeIf(retrieval -> { retrievals.removeIf(retrieval -> {
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) { if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
return retrieval.getDistance() > distance.doubleValue(); return retrieval.getDistance() > 1 - distance.doubleValue();
} }
return false; return false;
}); });

View File

@@ -66,7 +66,7 @@ public class QueryTypeParser implements SemanticParser {
} }
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
selectFields.addAll(whereFields); selectFields.addAll(whereFields);
List<String> selectWhereFilterByTimeFields = filterByTimeFields(whereFields); List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) { if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName) Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet()); .collect(Collectors.toSet());

View File

@@ -52,8 +52,8 @@ public class OptimizationConfig {
@Value("${embedding.mapper.round.number:10}") @Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber; private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.distance.threshold:0.01}") @Value("${embedding.mapper.threshold:0.99}")
private Double embeddingMapperDistanceThreshold; private Double embeddingMapperThreshold;
@Value("${s2SQL.linking.value.switch:true}") @Value("${s2SQL.linking.value.switch:true}")
private boolean useLinkingValueSwitch; private boolean useLinkingValueSwitch;
@@ -135,8 +135,8 @@ public class OptimizationConfig {
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
} }
public Double getEmbeddingMapperDistanceThreshold() { public Double getEmbeddingMapperThreshold() {
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold); return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
} }
public boolean isUseLinkingValueSwitch() { public boolean isUseLinkingValueSwitch() {