mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Remove irrelevant topN field information during the parsing of large models. (#1558)
This commit is contained in:
@@ -649,22 +649,6 @@ public class SqlSelectHelper {
|
||||
return withNameList;
|
||||
}
|
||||
|
||||
public static Map<String, WithItem> getWith(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
if (selectStatement == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
Map<String, WithItem> withMap = new HashMap<>();
|
||||
List<WithItem> withItemList = selectStatement.getWithItemsList();
|
||||
if (!CollectionUtils.isEmpty(withItemList)) {
|
||||
for (int i = 0; i < withItemList.size(); i++) {
|
||||
WithItem withItem = withItemList.get(i);
|
||||
withMap.put(withItem.getAlias().getName(), withItem);
|
||||
}
|
||||
}
|
||||
return withMap;
|
||||
}
|
||||
|
||||
public static Table getTable(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
if (selectStatement == null) {
|
||||
|
||||
@@ -11,12 +11,6 @@ public class LLMParserConfig {
|
||||
@Value("${s2.recall.max.retries:3}")
|
||||
private int recallMaxRetries;
|
||||
|
||||
@Value("${s2.dimension.topn:10}")
|
||||
private Integer dimensionTopN;
|
||||
|
||||
@Value("${s2.metric.topn:10}")
|
||||
private Integer metricTopN;
|
||||
|
||||
@Value("${s2.tag.topn:20}")
|
||||
private Integer tagTopN;
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -37,10 +36,6 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_ST
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMRequestService {
|
||||
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
@@ -81,7 +76,7 @@ public class LLMRequestService {
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
|
||||
List<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
|
||||
if (Objects.nonNull(semanticSchema.getDataSetSchemaMap())
|
||||
&& Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) {
|
||||
TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap()
|
||||
@@ -128,17 +123,6 @@ public class LLMRequestService {
|
||||
return result;
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(ChatQueryContext queryCtx, Long dataSetId,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
@@ -215,24 +199,6 @@ public class LLMRequestService {
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(ChatQueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = new HashSet<>();
|
||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(dimensions);
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(metrics);
|
||||
return results;
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
@@ -255,23 +221,20 @@ public class LLMRequestService {
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<SchemaElement> schemaElements = matchedElements.stream()
|
||||
return matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.DIMENSION.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
return schemaElementMatch.getElement();
|
||||
})
|
||||
.map(schemaElementMatch -> schemaElementMatch.getElement())
|
||||
.collect(Collectors.toList());
|
||||
return schemaElements;
|
||||
}
|
||||
|
||||
protected Set<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
protected List<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Set<String> fieldNameList = matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
@@ -289,6 +252,7 @@ public class LLMRequestService {
|
||||
return schemaElementMatch.getWord();
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
return fieldNameList;
|
||||
|
||||
return new ArrayList<>(fieldNameList);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user