mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement):use model match score to select model (#268)
This commit is contained in:
@@ -23,10 +23,10 @@ import org.apache.commons.collections.CollectionUtils;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class HeuristicModelResolver implements ModelResolver {
|
public class HeuristicModelResolver implements ModelResolver {
|
||||||
|
|
||||||
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
|
protected static Long selectModelBySchemaElementMatchScore(Map<Long, SemanticQuery> modelQueryModes,
|
||||||
SchemaMapInfo schemaMap) {
|
SchemaMapInfo schemaMap) {
|
||||||
//model count priority
|
//model count priority
|
||||||
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
|
Long modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
|
||||||
if (Objects.nonNull(modelIdByModelCount)) {
|
if (Objects.nonNull(modelIdByModelCount)) {
|
||||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||||
return modelIdByModelCount;
|
return modelIdByModelCount;
|
||||||
@@ -59,26 +59,28 @@ public class HeuristicModelResolver implements ModelResolver {
|
|||||||
return 0L;
|
return 0L;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
|
private static Long getModelIdByMatchModelScore(SchemaMapInfo schemaMap) {
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||||
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
|
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||||
|
Map<Long, Double> modelIdToModelScore = new HashMap<>();
|
||||||
if (Objects.nonNull(modelElementMatches)) {
|
if (Objects.nonNull(modelElementMatches)) {
|
||||||
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||||
Long modelId = modelElementMatch.getKey();
|
Long modelId = modelElementMatch.getKey();
|
||||||
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
|
List<Double> modelMatchesScore = modelElementMatch.getValue().stream().filter(
|
||||||
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
|
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
|
||||||
).collect(Collectors.toList());
|
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||||
|
|
||||||
if (!CollectionUtils.isEmpty(modelMatches)) {
|
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
|
||||||
Integer count = modelMatches.size();
|
// get sum of model match score
|
||||||
modelIdToModelCount.put(modelId, count);
|
double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||||
|
modelIdToModelScore.put(modelId, score);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
|
Entry<Long, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
|
||||||
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
|
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
|
||||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
|
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
|
||||||
if (Objects.nonNull(maxModelCount)) {
|
if (Objects.nonNull(maxModelScore)) {
|
||||||
return maxModelCount.getKey();
|
return maxModelScore.getKey();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
@@ -205,8 +207,8 @@ public class HeuristicModelResolver implements ModelResolver {
|
|||||||
log.info("selectModel {} ", selectModel);
|
log.info("selectModel {} ", selectModel);
|
||||||
return selectModel;
|
return selectModel;
|
||||||
}
|
}
|
||||||
// get the max SchemaElementType number
|
// get the max SchemaElementType match score
|
||||||
return selectModelBySchemaElementCount(modelQueryModes, schemaMap);
|
return selectModelBySchemaElementMatchScore(modelQueryModes, schemaMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
|
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
|
||||||
|
|||||||
Reference in New Issue
Block a user