(improvement):use model match score to select model (#268)

This commit is contained in:
Scott
2023-10-20 22:43:41 +08:00
committed by GitHub
parent d710986923
commit 8bd43f113b

View File

@@ -23,10 +23,10 @@ import org.apache.commons.collections.CollectionUtils;
@Slf4j @Slf4j
public class HeuristicModelResolver implements ModelResolver { public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes, protected static Long selectModelBySchemaElementMatchScore(Map<Long, SemanticQuery> modelQueryModes,
SchemaMapInfo schemaMap) { SchemaMapInfo schemaMap) {
//model count priority //model count priority
Long modelIdByModelCount = getModelIdByModelCount(schemaMap); Long modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
if (Objects.nonNull(modelIdByModelCount)) { if (Objects.nonNull(modelIdByModelCount)) {
log.info("selectModel by model count:{}", modelIdByModelCount); log.info("selectModel by model count:{}", modelIdByModelCount);
return modelIdByModelCount; return modelIdByModelCount;
@@ -59,26 +59,28 @@ public class HeuristicModelResolver implements ModelResolver {
return 0L; return 0L;
} }
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) { private static Long getModelIdByMatchModelScore(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches(); Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
Map<Long, Integer> modelIdToModelCount = new HashMap<>(); // calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
Map<Long, Double> modelIdToModelScore = new HashMap<>();
if (Objects.nonNull(modelElementMatches)) { if (Objects.nonNull(modelElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) { for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
Long modelId = modelElementMatch.getKey(); Long modelId = modelElementMatch.getKey();
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter( List<Double> modelMatchesScore = modelElementMatch.getValue().stream().filter(
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()) elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
).collect(Collectors.toList()); .map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(modelMatches)) { if (!CollectionUtils.isEmpty(modelMatchesScore)) {
Integer count = modelMatches.size(); // get sum of model match score
modelIdToModelCount.put(modelId, count); double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
modelIdToModelScore.put(modelId, score);
} }
} }
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream() Entry<Long, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null); .max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount); log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
if (Objects.nonNull(maxModelCount)) { if (Objects.nonNull(maxModelScore)) {
return maxModelCount.getKey(); return maxModelScore.getKey();
} }
} }
return null; return null;
@@ -205,8 +207,8 @@ public class HeuristicModelResolver implements ModelResolver {
log.info("selectModel {} ", selectModel); log.info("selectModel {} ", selectModel);
return selectModel; return selectModel;
} }
// get the max SchemaElementType number // get the max SchemaElementType match score
return selectModelBySchemaElementCount(modelQueryModes, schemaMap); return selectModelBySchemaElementMatchScore(modelQueryModes, schemaMap);
} }
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext, public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,