(improvement)(chat) remove domain in wordservice and the model id takes precedence over the number of model aliases (#87)

This commit is contained in:
lexluo09
2023-09-13 21:25:31 +08:00
committed by GitHub
parent c802c508fb
commit 6a98ce9d28
5 changed files with 65 additions and 14 deletions

View File

@@ -1,23 +1,23 @@
package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Objects;
import java.util.List;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
@@ -25,6 +25,13 @@ public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
SchemaMapInfo schemaMap) {
//model count priority
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
if (Objects.nonNull(modelIdByModelCount)) {
log.info("selectModel by model count:{}", modelIdByModelCount);
return modelIdByModelCount;
}
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
if (modelTypeMap.size() == 1) {
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
@@ -33,6 +40,7 @@ public class HeuristicModelResolver implements ModelResolver {
return modelSelect;
}
} else {
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
@@ -51,6 +59,31 @@ public class HeuristicModelResolver implements ModelResolver {
return 0L;
}
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
if (Objects.nonNull(modelElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
Long modelId = modelElementMatch.getKey();
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(modelMatches)) {
Integer count = modelMatches.size();
modelIdToModelCount.put(modelId, count);
}
}
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
if (Objects.nonNull(maxModelCount)) {
return maxModelCount.getKey();
}
}
return null;
}
/**
* to check can switch Model if context exit Model
*