diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java index 7fb3e3d2e..1ed91eddf 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java @@ -74,8 +74,8 @@ public class SysParameter { "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置")); parameters.add(new Parameter("embedding.mapper.number", "5", "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置")); - parameters.add(new Parameter("embedding.mapper.distance.threshold", - "0.01", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置")); + parameters.add(new Parameter("embedding.mapper.threshold", + "0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置")); //parser config Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT", diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java index 1bc4acc51..3fbe843be 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java @@ -49,13 +49,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { @Override public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset) { + String detectSegment, int offset) { } @Override protected void detectByBatch(QueryContext queryContext, Set results, Set detectDataSetIds, - Set detectSegments) { + Set detectSegments) { List queryTextsList = detectSegments.stream() .map(detectSegment -> detectSegment.trim()) @@ -73,9 +73,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { } private void detectByQueryTextsSub(Set results, Set detectDataSetIds, - List queryTextsSub, Map> modelIdToDataSetIds) { + List queryTextsSub, Map> modelIdToDataSetIds) { int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber(); - Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold(); + Double distance = optimizationConfig.getEmbeddingMapperThreshold(); // step1. build query params RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); @@ -94,7 +94,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { if (CollectionUtils.isNotEmpty(retrievals)) { retrievals.removeIf(retrieval -> { if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) { - return retrieval.getDistance() > distance.doubleValue(); + return retrieval.getDistance() > 1 - distance.doubleValue(); } return false; }); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java index 2a86b69b0..94807e55b 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/QueryTypeParser.java @@ -66,7 +66,7 @@ public class QueryTypeParser implements SemanticParser { } List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); selectFields.addAll(whereFields); - List selectWhereFilterByTimeFields = filterByTimeFields(whereFields); + List selectWhereFilterByTimeFields = filterByTimeFields(selectFields); if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) { Set tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName) .collect(Collectors.toSet()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java index 5871d8fc6..f1c8f3e83 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java @@ -52,8 +52,8 @@ public class OptimizationConfig { @Value("${embedding.mapper.round.number:10}") private int embeddingMapperRoundNumber; - @Value("${embedding.mapper.distance.threshold:0.01}") - private Double embeddingMapperDistanceThreshold; + @Value("${embedding.mapper.threshold:0.99}") + private Double embeddingMapperThreshold; @Value("${s2SQL.linking.value.switch:true}") private boolean useLinkingValueSwitch; @@ -135,8 +135,8 @@ public class OptimizationConfig { return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); } - public Double getEmbeddingMapperDistanceThreshold() { - return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold); + public Double getEmbeddingMapperThreshold() { + return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold); } public boolean isUseLinkingValueSwitch() {