diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/HeuristicModelResolver.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/HeuristicModelResolver.java index f3c01a0a4..a342398c9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/HeuristicModelResolver.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/HeuristicModelResolver.java @@ -1,23 +1,23 @@ package com.tencent.supersonic.chat.parser.plugin.function; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.ChatContext; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; -import com.tencent.supersonic.chat.api.component.SemanticQuery; -import lombok.extern.slf4j.Slf4j; - -import java.util.Map; +import java.util.Comparator; import java.util.HashMap; -import java.util.Objects; -import java.util.List; import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; - +import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @Slf4j @@ -25,6 +25,13 @@ public class HeuristicModelResolver implements ModelResolver { protected static Long selectModelBySchemaElementCount(Map modelQueryModes, SchemaMapInfo schemaMap) { + //model count priority + Long modelIdByModelCount = getModelIdByModelCount(schemaMap); + if (Objects.nonNull(modelIdByModelCount)) { + log.info("selectModel by model count:{}", modelIdByModelCount); + return modelIdByModelCount; + } + Map modelTypeMap = getModelTypeMap(schemaMap); if (modelTypeMap.size() == 1) { Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey(); @@ -33,6 +40,7 @@ public class HeuristicModelResolver implements ModelResolver { return modelSelect; } } else { + Map.Entry maxModel = modelTypeMap.entrySet().stream() .filter(entry -> modelQueryModes.containsKey(entry.getKey())) .sorted((o1, o2) -> { @@ -51,6 +59,31 @@ public class HeuristicModelResolver implements ModelResolver { return 0L; } + private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) { + Map> modelElementMatches = schemaMap.getModelElementMatches(); + Map modelIdToModelCount = new HashMap<>(); + if (Objects.nonNull(modelElementMatches)) { + for (Entry> modelElementMatch : modelElementMatches.entrySet()) { + Long modelId = modelElementMatch.getKey(); + List modelMatches = modelElementMatch.getValue().stream().filter( + elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()) + ).collect(Collectors.toList()); + + if (!CollectionUtils.isEmpty(modelMatches)) { + Integer count = modelMatches.size(); + modelIdToModelCount.put(modelId, count); + } + } + Entry maxModelCount = modelIdToModelCount.entrySet().stream() + .max(Comparator.comparingInt(o -> o.getValue())).orElse(null); + log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount); + if (Objects.nonNull(maxModelCount)) { + return maxModelCount.getKey(); + } + } + return null; + } + /** * to check can switch Model if context exit Model * diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java index 674928689..cc6b6fdf8 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java @@ -15,7 +15,7 @@ public class WordBuilderFactory { static { wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder()); wordNatures.put(DictWordType.METRIC, new MetricWordBuilder()); - wordNatures.put(DictWordType.DOMAIN, new ModelWordBuilder()); + wordNatures.put(DictWordType.MODEL, new ModelWordBuilder()); wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder()); wordNatures.put(DictWordType.VALUE, new ValueWordBuilder()); } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/WordService.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/WordService.java index 86ec65b73..1aaece555 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/WordService.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/WordService.java @@ -29,7 +29,7 @@ public class WordService { addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words); addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words); - addWordsByType(DictWordType.DOMAIN, semanticSchema.getModels(), words); + addWordsByType(DictWordType.MODEL, semanticSchema.getModels(), words); addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words); addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words); diff --git a/chat/knowledge/src/test/java/com/tencent/supersonic/knowledge/utils/NatureHelperTest.java b/chat/knowledge/src/test/java/com/tencent/supersonic/knowledge/utils/NatureHelperTest.java new file mode 100644 index 000000000..2ae155177 --- /dev/null +++ b/chat/knowledge/src/test/java/com/tencent/supersonic/knowledge/utils/NatureHelperTest.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.knowledge.utils; + + +import cn.hutool.core.lang.Assert; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import org.junit.jupiter.api.Test; + +class NatureHelperTest { + + @Test + void convertToElementType() { + SchemaElementType schemaElementType = NatureHelper.convertToElementType("_1"); + + Assert.equals(schemaElementType, SchemaElementType.MODEL); + } +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java index 66b3a0d13..7322555b7 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java @@ -8,11 +8,13 @@ import org.apache.commons.lang3.StringUtils; */ public enum DictWordType { METRIC("metric"), + DIMENSION("dimension"), + VALUE("value"), - DOMAIN("dm"), MODEL("model"), + ENTITY("entity"), NUMBER("m"), @@ -44,7 +46,7 @@ public enum DictWordType { //domain String[] natures = nature.split(DictWordType.NATURE_SPILT); if (natures.length == 2 && StringUtils.isNumeric(natures[1])) { - return DOMAIN; + return MODEL; } //dimension value if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) {