From c82c2d0a9542b542826883dc5bfc1057f3e50b53 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 12 Aug 2024 17:31:42 +0800 Subject: [PATCH] (improvement)(chat) Remove irrelevant topN field information during the parsing of large models. (#1558) --- .../common/jsqlparser/SqlSelectHelper.java | 16 ------ .../chat/parser/llm/LLMParserConfig.java | 6 --- .../chat/parser/llm/LLMRequestService.java | 50 +++---------------- 3 files changed, 7 insertions(+), 65 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index ded4ea1d0..60e743975 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -649,22 +649,6 @@ public class SqlSelectHelper { return withNameList; } - public static Map getWith(String sql) { - Select selectStatement = getSelect(sql); - if (selectStatement == null) { - return new HashMap<>(); - } - Map withMap = new HashMap<>(); - List withItemList = selectStatement.getWithItemsList(); - if (!CollectionUtils.isEmpty(withItemList)) { - for (int i = 0; i < withItemList.size(); i++) { - WithItem withItem = withItemList.get(i); - withMap.put(withItem.getAlias().getName(), withItem); - } - } - return withMap; - } - public static Table getTable(String sql) { Select selectStatement = getSelect(sql); if (selectStatement == null) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java index 08c4701fb..536b53dcf 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java @@ -11,12 +11,6 @@ public class LLMParserConfig { @Value("${s2.recall.max.retries:3}") private int recallMaxRetries; - @Value("${s2.dimension.topn:10}") - private Integer dimensionTopN; - - @Value("${s2.metric.topn:10}") - private Integer metricTopN; - @Value("${s2.tag.topn:20}") private Integer tagTopN; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 760e43fd0..58b3f98ae 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -23,7 +23,6 @@ import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -37,10 +36,6 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_ST @Slf4j @Service public class LLMRequestService { - - @Autowired - private LLMParserConfig llmParserConfig; - @Autowired private ParserConfig parserConfig; @@ -81,7 +76,7 @@ public class LLMRequestService { llmSchema.setDataSetName(dataSetIdToName.get(dataSetId)); llmSchema.setDomainName(dataSetIdToName.get(dataSetId)); - List fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig); + List fieldNameList = getMatchedFieldNames(queryCtx, dataSetId); if (Objects.nonNull(semanticSchema.getDataSetSchemaMap()) && Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) { TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap() @@ -128,17 +123,6 @@ public class LLMRequestService { return result; } - protected List getFieldNameList(ChatQueryContext queryCtx, Long dataSetId, - LLMParserConfig llmParserConfig) { - - Set results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig); - - Set fieldNameList = getMatchedFieldNames(queryCtx, dataSetId); - - results.addAll(fieldNameList); - return new ArrayList<>(results); - } - protected List getTerms(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { @@ -215,24 +199,6 @@ public class LLMRequestService { .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } - private Set getTopNFieldNames(ChatQueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) { - SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - Set results = new HashSet<>(); - Set dimensions = semanticSchema.getDimensions(dataSetId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getDimensionTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - results.addAll(dimensions); - Set metrics = semanticSchema.getMetrics(dataSetId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getMetricTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - results.addAll(metrics); - return results; - } - protected List getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { @@ -255,23 +221,20 @@ public class LLMRequestService { if (CollectionUtils.isEmpty(matchedElements)) { return Collections.emptyList(); } - List schemaElements = matchedElements.stream() + return matchedElements.stream() .filter(schemaElementMatch -> { SchemaElementType elementType = schemaElementMatch.getElement().getType(); return SchemaElementType.DIMENSION.equals(elementType); }) - .map(schemaElementMatch -> { - return schemaElementMatch.getElement(); - }) + .map(schemaElementMatch -> schemaElementMatch.getElement()) .collect(Collectors.toList()); - return schemaElements; } - protected Set getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) { + protected List getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) { Map itemIdToName = getItemIdToName(queryCtx, dataSetId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { - return new HashSet<>(); + return new ArrayList<>(); } Set fieldNameList = matchedElements.stream() .filter(schemaElementMatch -> { @@ -289,6 +252,7 @@ public class LLMRequestService { return schemaElementMatch.getWord(); }) .collect(Collectors.toSet()); - return fieldNameList; + + return new ArrayList<>(fieldNameList); } }