mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:18:23 +00:00
(improvement):use model match score to select model (#268)
This commit is contained in:
@@ -23,10 +23,10 @@ import org.apache.commons.collections.CollectionUtils;
|
||||
@Slf4j
|
||||
public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
|
||||
protected static Long selectModelBySchemaElementMatchScore(Map<Long, SemanticQuery> modelQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//model count priority
|
||||
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
|
||||
Long modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
|
||||
if (Objects.nonNull(modelIdByModelCount)) {
|
||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||
return modelIdByModelCount;
|
||||
@@ -59,26 +59,28 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
return 0L;
|
||||
}
|
||||
|
||||
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
|
||||
private static Long getModelIdByMatchModelScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
|
||||
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<Long, Double> modelIdToModelScore = new HashMap<>();
|
||||
if (Objects.nonNull(modelElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||
Long modelId = modelElementMatch.getKey();
|
||||
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
|
||||
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
|
||||
).collect(Collectors.toList());
|
||||
List<Double> modelMatchesScore = modelElementMatch.getValue().stream().filter(
|
||||
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(modelMatches)) {
|
||||
Integer count = modelMatches.size();
|
||||
modelIdToModelCount.put(modelId, count);
|
||||
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
|
||||
// get sum of model match score
|
||||
double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
modelIdToModelScore.put(modelId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
|
||||
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
|
||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
|
||||
if (Objects.nonNull(maxModelCount)) {
|
||||
return maxModelCount.getKey();
|
||||
Entry<Long, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
|
||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
|
||||
if (Objects.nonNull(maxModelScore)) {
|
||||
return maxModelScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
@@ -205,8 +207,8 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
log.info("selectModel {} ", selectModel);
|
||||
return selectModel;
|
||||
}
|
||||
// get the max SchemaElementType number
|
||||
return selectModelBySchemaElementCount(modelQueryModes, schemaMap);
|
||||
// get the max SchemaElementType match score
|
||||
return selectModelBySchemaElementMatchScore(modelQueryModes, schemaMap);
|
||||
}
|
||||
|
||||
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
|
||||
|
||||
Reference in New Issue
Block a user