(improvement)(chat) Remove irrelevant topN field information during LLM parsing. (#1558)

lexluo09 authored on 2024-08-12 17:31:42 +08:00, committed by GitHub
parent 0c70df12ca
commit c82c2d0a95
3 changed files with 7 additions and 65 deletions
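
In effect, the field names handed to the LLM prompt are no longer padded with the top-N most-used dimensions and metrics (sorted by useCnt and capped by dimensionTopN/metricTopN); only the fields actually matched in the user query are kept. Below is a minimal, self-contained sketch of that behavioral change. The FieldSelectionSketch class, the Field record, the beforeCommit/afterCommit helpers and the sample values are illustrative stand-ins for the SchemaElement/useCnt logic visible in the diff, not code from the repository.

import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class FieldSelectionSketch {

    // Simplified stand-in for SchemaElement: just a field name and its usage count.
    record Field(String name, long useCnt) {}

    // Old behavior (removed by this commit): take the topN most-used schema fields,
    // then add the fields matched in the user query.
    static Set<String> beforeCommit(List<Field> schemaFields, Set<String> matched, int topN) {
        Set<String> results = schemaFields.stream()
                .sorted(Comparator.comparing(Field::useCnt).reversed())
                .limit(topN)
                .map(Field::name)
                .collect(Collectors.toCollection(LinkedHashSet::new));
        results.addAll(matched);
        return results;
    }

    // New behavior: only the fields actually matched in the query are kept.
    static Set<String> afterCommit(Set<String> matched) {
        return matched;
    }

    public static void main(String[] args) {
        List<Field> schemaFields = List.of(
                new Field("pv", 1000), new Field("uv", 800),
                new Field("department", 50), new Field("user_name", 20));
        Set<String> matched = Set.of("user_name");

        System.out.println(beforeCommit(schemaFields, matched, 2)); // [pv, uv, user_name]
        System.out.println(afterCommit(matched));                   // [user_name]
    }
}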

SqlSelectHelper.java

@@ -649,22 +649,6 @@ public class SqlSelectHelper {
         return withNameList;
     }
-    public static Map<String, WithItem> getWith(String sql) {
-        Select selectStatement = getSelect(sql);
-        if (selectStatement == null) {
-            return new HashMap<>();
-        }
-        Map<String, WithItem> withMap = new HashMap<>();
-        List<WithItem> withItemList = selectStatement.getWithItemsList();
-        if (!CollectionUtils.isEmpty(withItemList)) {
-            for (int i = 0; i < withItemList.size(); i++) {
-                WithItem withItem = withItemList.get(i);
-                withMap.put(withItem.getAlias().getName(), withItem);
-            }
-        }
-        return withMap;
-    }
     public static Table getTable(String sql) {
         Select selectStatement = getSelect(sql);
         if (selectStatement == null) {

LLMParserConfig.java

@@ -11,12 +11,6 @@ public class LLMParserConfig {
@Value("${s2.recall.max.retries:3}")
private int recallMaxRetries;
@Value("${s2.dimension.topn:10}")
private Integer dimensionTopN;
@Value("${s2.metric.topn:10}")
private Integer metricTopN;
@Value("${s2.tag.topn:20}")
private Integer tagTopN;
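
With these bindings removed, the s2.dimension.topn, s2.metric.topn and s2.tag.topn properties are no longer read by LLMParserConfig.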

LLMRequestService.java

@@ -23,7 +23,6 @@ import org.springframework.util.CollectionUtils;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -37,10 +36,6 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_ST
 @Slf4j
 @Service
 public class LLMRequestService {
-    @Autowired
-    private LLMParserConfig llmParserConfig;
     @Autowired
     private ParserConfig parserConfig;
@@ -81,7 +76,7 @@ public class LLMRequestService {
         llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
         llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
-        List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
+        List<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
         if (Objects.nonNull(semanticSchema.getDataSetSchemaMap())
                 && Objects.nonNull(semanticSchema.getDataSetSchemaMap().get(dataSetId))) {
             TimeDefaultConfig timeDefaultConfig = semanticSchema.getDataSetSchemaMap()
@@ -128,17 +123,6 @@ public class LLMRequestService {
         return result;
     }
-    protected List<String> getFieldNameList(ChatQueryContext queryCtx, Long dataSetId,
-            LLMParserConfig llmParserConfig) {
-        Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
-        Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
-        results.addAll(fieldNameList);
-        return new ArrayList<>(results);
-    }
     protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId) {
         List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
         if (CollectionUtils.isEmpty(matchedElements)) {
@@ -215,24 +199,6 @@ public class LLMRequestService {
                 .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
     }
-    private Set<String> getTopNFieldNames(ChatQueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
-        SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
-        Set<String> results = new HashSet<>();
-        Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
-                .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
-                .limit(llmParserConfig.getDimensionTopN())
-                .map(entry -> entry.getName())
-                .collect(Collectors.toSet());
-        results.addAll(dimensions);
-        Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
-                .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
-                .limit(llmParserConfig.getMetricTopN())
-                .map(entry -> entry.getName())
-                .collect(Collectors.toSet());
-        results.addAll(metrics);
-        return results;
-    }
     protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
         List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
         if (CollectionUtils.isEmpty(matchedElements)) {
@@ -255,23 +221,20 @@ public class LLMRequestService {
         if (CollectionUtils.isEmpty(matchedElements)) {
             return Collections.emptyList();
         }
-        List<SchemaElement> schemaElements = matchedElements.stream()
+        return matchedElements.stream()
                 .filter(schemaElementMatch -> {
                     SchemaElementType elementType = schemaElementMatch.getElement().getType();
                     return SchemaElementType.DIMENSION.equals(elementType);
                 })
-                .map(schemaElementMatch -> {
-                    return schemaElementMatch.getElement();
-                })
+                .map(schemaElementMatch -> schemaElementMatch.getElement())
                 .collect(Collectors.toList());
-        return schemaElements;
     }
-    protected Set<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
+    protected List<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
         Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
         List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
         if (CollectionUtils.isEmpty(matchedElements)) {
-            return new HashSet<>();
+            return new ArrayList<>();
         }
         Set<String> fieldNameList = matchedElements.stream()
                 .filter(schemaElementMatch -> {
@@ -289,6 +252,7 @@ public class LLMRequestService {
                     return schemaElementMatch.getWord();
                 })
                 .collect(Collectors.toSet());
-        return fieldNameList;
+        return new ArrayList<>(fieldNameList);
     }
 }
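
Note that getMatchedFieldNames still de-duplicates names by collecting into a Set, then returns them as a new ArrayList so callers receive the List<String> the updated signature promises.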