diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java index 6b25b6fc1..527d04a51 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java @@ -49,9 +49,13 @@ public class ParserConfig extends ParameterConfig { public static final Parameter PARSER_SHOW_COUNT = new Parameter("s2.parser.show.count", "3", "解析结果展示个数", "前端展示的解析个数", "number", "语义解析配置"); + public static final Parameter PARSER_FIELDS_COUNT_THRESHOLD = + new Parameter("s2.parser.field.count.threshold", "3", "语义字段个数阈值", + "如果映射字段小于该阈值,则将数据集所有字段输入LLM", "number", "语义解析配置"); + @Override public List getSysParameters() { return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_FEW_SHOT_NUMBER, - PARSER_SELF_CONSISTENCY_NUMBER, PARSER_SHOW_COUNT); + PARSER_SELF_CONSISTENCY_NUMBER, PARSER_SHOW_COUNT, PARSER_FIELDS_COUNT_THRESHOLD); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 39cebf538..3de69f0cc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -24,8 +24,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE; -import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE; +import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*; @Slf4j @Service @@ -43,15 +42,23 @@ public class LLMRequestService { Map dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName(); String queryText = queryCtx.getRequest().getQueryText(); + LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); + int fieldCntThreshold = + Integer.valueOf(parserConfig.getParameterValue(PARSER_FIELDS_COUNT_THRESHOLD)); + if (queryCtx.getMapInfo().getMatchedElements(dataSetId).size() <= fieldCntThreshold) { + llmSchema.setMetrics(queryCtx.getSemanticSchema().getMetrics()); + llmSchema.setDimensions(queryCtx.getSemanticSchema().getDimensions()); + } else { + llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId)); + llmSchema.setDimensions(getMappedDimensions(queryCtx, dataSetId)); + } + LLMReq llmReq = new LLMReq(); llmReq.setQueryText(queryText); - LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); llmReq.setSchema(llmSchema); llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId)); llmSchema.setDataSetId(dataSetId); llmSchema.setDataSetName(dataSetIdToName.get(dataSetId)); - llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId)); - llmSchema.setDimensions(getMappedDimensions(queryCtx, dataSetId)); llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId)); llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));