diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java index 1ed91eddf..34979e691 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java @@ -4,6 +4,7 @@ import com.google.common.collect.Lists; import lombok.Data; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; + import java.util.Arrays; import java.util.List; import java.util.Map; @@ -57,12 +58,14 @@ public class SysParameter { parameters.add(new Parameter("metric.dimension.threshold", "0.3", "指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "number", "Mapper相关配置")); - parameters.add(new Parameter("metric.dimension.min.threshold", "0.3", - "指标名、维度名最小文本相似度阈值", - "最小编辑距离阈值, 在FuzzyNameMapper中, 如果上面设定的编辑距离阈值的1/2大于该最小编辑距离, 则取上面设定阈值的1/2作为阈值, 否则取该阈值", + parameters.add(new Parameter("metric.dimension.min.threshold", "0.25", + "指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); parameters.add(new Parameter("dimension.value.threshold", "0.5", - "维度值最小文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", + "维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", + "number", "Mapper相关配置")); + parameters.add(new Parameter("dimension.value.min.threshold", "0.3", + "维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); //embedding mapper config @@ -76,6 +79,8 @@ public class SysParameter { "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置")); parameters.add(new Parameter("embedding.mapper.threshold", "0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置")); + parameters.add(new Parameter("embedding.mapper.min.threshold", + "0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); //parser config Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT", diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java index 7551dd202..33e0aa156 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java @@ -1,7 +1,12 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum MapModeEnum { - STRICT, - MODERATE, - LOOSE; + STRICT(0), + MODERATE(2), + LOOSE(4); + public int threshold; + + MapModeEnum(Integer threshold) { + this.threshold = threshold; + } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java index 2280e2f77..6b53b06bf 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryMapReq.java @@ -14,5 +14,5 @@ public class QueryMapReq { private List dataSetNames; private User user; private Integer topN = 10; - private MapModeEnum mapModeEnum; + private MapModeEnum mapModeEnum = MapModeEnum.STRICT; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java index dea2fc0cf..6ee1eb61b 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java @@ -4,6 +4,7 @@ import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import lombok.Data; import java.util.Set; @@ -17,5 +18,6 @@ public class QueryReq { private QueryFilters queryFilters; private boolean saveAnswer = true; private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; + private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private SchemaMapInfo mapInfo = new SchemaMapInfo(); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java index 2a2978aea..3bf8151c1 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.core.chat.mapper; +import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper; @@ -9,6 +10,7 @@ import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; + import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -29,7 +31,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { @Override public Map> match(QueryContext queryContext, List terms, - Set detectDataSetIds) { + Set detectDataSetIds) { String text = queryContext.getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; @@ -69,8 +71,7 @@ public abstract class BaseMatchStrategy implements MatchStrategy { } protected void detectByBatch(QueryContext queryContext, Set results, Set detectDataSetIds, - Set detectSegments) { - return; + Set detectSegments) { } public Map getRegOffsetToLength(List terms) { @@ -152,6 +153,11 @@ public abstract class BaseMatchStrategy implements MatchStrategy { public abstract String getMapKey(T a); public abstract void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset); + String detectSegment, int offset); + public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { + double decreaseAmount = (threshold - minThreshold) / 4; + double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount; + return divideThreshold >= minThreshold ? divideThreshold : minThreshold; + } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java index 550654384..965a29d28 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java @@ -56,7 +56,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy } public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset) { + String detectSegment, int offset) { if (StringUtils.isBlank(detectSegment)) { return; } @@ -94,22 +94,20 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy } private Double getThreshold(QueryContext queryContext) { - Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig(); - Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig(); + + Double threshold = optimizationConfig.getMetricDimensionThresholdConfig(); + Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig(); Map> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches(); boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1); if (!existElement) { - double halfThreshold = metricDimensionThresholdConfig / 2; - - metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold - : metricDimensionMinThresholdConfig; - log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}", - modelElementMatches, metricDimensionThresholdConfig); + threshold = threshold / 2; + log.info("ModelElementMatches:{},not exist Element threshold reduce by half:{}", + modelElementMatches, threshold); } - return metricDimensionThresholdConfig; + return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum()); } private Map> getNameToItems(List models) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java index 3fbe843be..40a8ddcf4 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java @@ -68,16 +68,18 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { optimizationConfig.getEmbeddingMapperBatch()); for (List queryTextsSub : queryTextsSubList) { - detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext.getModelIdToDataSetIds()); + detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext); } } private void detectByQueryTextsSub(Set results, Set detectDataSetIds, - List queryTextsSub, Map> modelIdToDataSetIds) { - int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber(); - Double distance = optimizationConfig.getEmbeddingMapperThreshold(); - // step1. build query params + List queryTextsSub, QueryContext queryContext) { + Map> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds(); + int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber(); + double threshold = getThreshold(optimizationConfig.getEmbeddingMapperThreshold(), + optimizationConfig.getEmbeddingMapperMinThreshold(), queryContext.getMapModeEnum()); + // step1. build query params RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); // step2. retrieveQuery by detectSegment @@ -94,7 +96,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { if (CollectionUtils.isNotEmpty(retrievals)) { retrievals.removeIf(retrieval -> { if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) { - return retrieval.getDistance() > 1 - distance.doubleValue(); + return retrieval.getDistance() > 1 - threshold; } return false; }); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java index 7aab4964d..44184ee78 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java @@ -11,6 +11,7 @@ import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; + import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; @@ -39,7 +40,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { @Override public Map> match(QueryContext queryContext, List terms, - Set detectDataSetIds) { + Set detectDataSetIds) { String text = queryContext.getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; @@ -61,7 +62,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { } public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset) { + String detectSegment, int offset) { // step1. pre search Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); LinkedHashSet hanlpMapResults = knowledgeService.prefixSearch(detectSegment, @@ -84,7 +85,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { // step4. filter by similarity hanlpMapResults = hanlpMapResults.stream() .filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName()) - >= mapperHelper.getThresholdMatch(term.getNatures())) + >= getThresholdMatch(term.getNatures(), queryContext)) .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) .collect(Collectors.toCollection(LinkedHashSet::new)); @@ -118,4 +119,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { public String getMapKey(HanlpMapResult a) { return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures()); } + + public double getThresholdMatch(List natures, QueryContext queryContext) { + Double threshold = optimizationConfig.getMetricDimensionThresholdConfig(); + Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig(); + if (mapperHelper.existDimensionValues(natures)) { + threshold = optimizationConfig.getDimensionValueThresholdConfig(); + minThreshold = optimizationConfig.getDimensionValueMinThresholdConfig(); + } + return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum()); + + } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java index 2717bc346..36c686bda 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java @@ -45,13 +45,6 @@ public class MapperHelper { return index; } - public double getThresholdMatch(List natures) { - if (existDimensionValues(natures)) { - return optimizationConfig.getDimensionValueThresholdConfig(); - } - return optimizationConfig.getMetricDimensionThresholdConfig(); - } - /*** * exist dimension values * @param natures diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java index 2a78fbc4a..9ce610b16 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java @@ -26,6 +26,9 @@ public class OptimizationConfig { @Value("${metric.dimension.threshold:0.3}") private Double metricDimensionThresholdConfig; + @Value("${dimension.value.min.threshold:0.2}") + private Double dimensionValueMinThresholdConfig; + @Value("${dimension.value.threshold:0.5}") private Double dimensionValueThresholdConfig; @@ -52,6 +55,9 @@ public class OptimizationConfig { @Value("${embedding.mapper.round.number:10}") private int embeddingMapperRoundNumber; + @Value("${embedding.mapper.min.threshold:0.6}") + private Double embeddingMapperMinThreshold; + @Value("${embedding.mapper.threshold:0.99}") private Double embeddingMapperThreshold; @@ -95,6 +101,10 @@ public class OptimizationConfig { return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig); } + public Double getDimensionValueMinThresholdConfig() { + return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig); + } + public Double getDimensionValueThresholdConfig() { return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig); } @@ -135,6 +145,10 @@ public class OptimizationConfig { return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); } + public Double getEmbeddingMapperMinThreshold() { + return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold); + } + public Double getEmbeddingMapperThreshold() { return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java index c151bfd77..58d442071 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java @@ -6,6 +6,7 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; import com.tencent.supersonic.headless.core.config.OptimizationConfig; @@ -37,6 +38,7 @@ public class QueryContext { private QueryFilters queryFilters; private List candidateQueries = new ArrayList<>(); private SchemaMapInfo mapInfo = new SchemaMapInfo(); + private MapModeEnum mapModeEnum = MapModeEnum.STRICT; @JsonIgnore private SemanticSchema semanticSchema; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/MetaDiscoveryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/MetaDiscoveryApiController.java index 6dab347f6..01e944486 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/MetaDiscoveryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/MetaDiscoveryApiController.java @@ -15,7 +15,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @RestController -@RequestMapping("/api/semantic/query") +@RequestMapping("/api/semantic/meta") @Slf4j public class MetaDiscoveryApiController { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 081ca3fcc..292815d66 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -179,6 +179,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { .mapInfo(new SchemaMapInfo()) .modelIdToDataSetIds(modelIdToDataSetIds) .text2SQLType(queryReq.getText2SQLType()) + .mapModeEnum(queryReq.getMapModeEnum()) .build(); BeanUtils.copyProperties(queryReq, queryCtx); return queryCtx;