diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java index f93399cc2..b3a827f28 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java @@ -23,10 +23,10 @@ import org.apache.commons.collections.CollectionUtils; @Slf4j public class HeuristicModelResolver implements ModelResolver { - protected static Long selectModelBySchemaElementCount(Map modelQueryModes, + protected static Long selectModelBySchemaElementMatchScore(Map modelQueryModes, SchemaMapInfo schemaMap) { //model count priority - Long modelIdByModelCount = getModelIdByModelCount(schemaMap); + Long modelIdByModelCount = getModelIdByMatchModelScore(schemaMap); if (Objects.nonNull(modelIdByModelCount)) { log.info("selectModel by model count:{}", modelIdByModelCount); return modelIdByModelCount; @@ -59,26 +59,28 @@ public class HeuristicModelResolver implements ModelResolver { return 0L; } - private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) { + private static Long getModelIdByMatchModelScore(SchemaMapInfo schemaMap) { Map> modelElementMatches = schemaMap.getModelElementMatches(); - Map modelIdToModelCount = new HashMap<>(); + // calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point + Map modelIdToModelScore = new HashMap<>(); if (Objects.nonNull(modelElementMatches)) { for (Entry> modelElementMatch : modelElementMatches.entrySet()) { Long modelId = modelElementMatch.getKey(); - List modelMatches = modelElementMatch.getValue().stream().filter( - elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()) - ).collect(Collectors.toList()); + List modelMatchesScore = modelElementMatch.getValue().stream().filter( + elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())) + .map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList()); - if (!CollectionUtils.isEmpty(modelMatches)) { - Integer count = modelMatches.size(); - modelIdToModelCount.put(modelId, count); + if (!CollectionUtils.isEmpty(modelMatchesScore)) { + // get sum of model match score + double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum(); + modelIdToModelScore.put(modelId, score); } } - Entry maxModelCount = modelIdToModelCount.entrySet().stream() - .max(Comparator.comparingInt(o -> o.getValue())).orElse(null); - log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount); - if (Objects.nonNull(maxModelCount)) { - return maxModelCount.getKey(); + Entry maxModelScore = modelIdToModelScore.entrySet().stream() + .max(Comparator.comparingDouble(o -> o.getValue())).orElse(null); + log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore); + if (Objects.nonNull(maxModelScore)) { + return maxModelScore.getKey(); } } return null; @@ -205,8 +207,8 @@ public class HeuristicModelResolver implements ModelResolver { log.info("selectModel {} ", selectModel); return selectModel; } - // get the max SchemaElementType number - return selectModelBySchemaElementCount(modelQueryModes, schemaMap); + // get the max SchemaElementType match score + return selectModelBySchemaElementMatchScore(modelQueryModes, schemaMap); } public Long selectModel(Map modelQueryModes, QueryReq queryContext,