diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java index 180e2f0d6..a00480f78 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java @@ -76,6 +76,9 @@ public class OptimizationConfig { @Value("${text2sql.collection.name:text2dsl_agent_collection}") private String text2sqlCollectionName; + @Value("${parse.show.count:3}") + private Integer parseShowCount; + @Autowired private SysParameterService sysParameterService; @@ -147,6 +150,10 @@ public class OptimizationConfig { return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode); } + public Integer getParseShowCount() { + return convertValue("parse.show.count", Integer.class, parseShowCount); + } + public T convertValue(String paramName, Class targetType, T defaultValue) { try { String value = sysParameterService.getSysParameter().getParameterByName(paramName); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java index cec1b2c6c..cd658ccf9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java @@ -7,11 +7,15 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.core.agent.Agent; +import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.plugin.Plugin; import com.tencent.supersonic.chat.core.query.SemanticQuery; +import com.tencent.supersonic.common.util.ContextUtils; import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -32,4 +36,15 @@ public class QueryContext { private Agent agent; private Map modelIdToChatRichConfig; private Map nameToPlugin; + + public List getCandidateQueries() { + OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); + Integer parseShowCount = optimizationConfig.getParseShowCount(); + candidateQueries = candidateQueries.stream() + .sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(), + Comparator.reverseOrder())) + .limit(parseShowCount) + .collect(Collectors.toList()); + return candidateQueries; + } } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java index 013b0078c..6cb15ca95 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java @@ -99,6 +99,11 @@ public class SysParameter { parameters.add(new Parameter("long.text.threshold", "0.8", "长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置")); + + //parse config + parameters.add(new Parameter("parse.show.count", "3", + "parseShowCount", "前端展示的解析个数", + "number", "Parser相关配置")); } }