mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) remove domain in wordservice and the model id takes precedence over the number of model aliases (#87)
This commit is contained in:
@@ -1,23 +1,23 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.Objects;
|
||||
import java.util.List;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@@ -25,6 +25,13 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//model count priority
|
||||
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
|
||||
if (Objects.nonNull(modelIdByModelCount)) {
|
||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||
return modelIdByModelCount;
|
||||
}
|
||||
|
||||
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||
if (modelTypeMap.size() == 1) {
|
||||
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
||||
@@ -33,6 +40,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
return modelSelect;
|
||||
}
|
||||
} else {
|
||||
|
||||
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
@@ -51,6 +59,31 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
return 0L;
|
||||
}
|
||||
|
||||
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
|
||||
if (Objects.nonNull(modelElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||
Long modelId = modelElementMatch.getKey();
|
||||
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
|
||||
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
|
||||
).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(modelMatches)) {
|
||||
Integer count = modelMatches.size();
|
||||
modelIdToModelCount.put(modelId, count);
|
||||
}
|
||||
}
|
||||
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
|
||||
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
|
||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
|
||||
if (Objects.nonNull(maxModelCount)) {
|
||||
return maxModelCount.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks whether the Model can be switched when a Model already exists in the context
|
||||
*
|
||||
|
||||
@@ -15,7 +15,7 @@ public class WordBuilderFactory {
|
||||
static {
|
||||
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
|
||||
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
|
||||
wordNatures.put(DictWordType.DOMAIN, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.MODEL, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
||||
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ public class WordService {
|
||||
|
||||
addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words);
|
||||
addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words);
|
||||
addWordsByType(DictWordType.DOMAIN, semanticSchema.getModels(), words);
|
||||
addWordsByType(DictWordType.MODEL, semanticSchema.getModels(), words);
|
||||
addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words);
|
||||
addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words);
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.knowledge.utils;
|
||||
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class NatureHelperTest {
|
||||
|
||||
@Test
|
||||
void convertToElementType() {
|
||||
SchemaElementType schemaElementType = NatureHelper.convertToElementType("_1");
|
||||
|
||||
Assert.equals(schemaElementType, SchemaElementType.MODEL);
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,13 @@ import org.apache.commons.lang3.StringUtils;
|
||||
*/
|
||||
public enum DictWordType {
|
||||
METRIC("metric"),
|
||||
|
||||
DIMENSION("dimension"),
|
||||
|
||||
VALUE("value"),
|
||||
|
||||
DOMAIN("dm"),
|
||||
MODEL("model"),
|
||||
|
||||
ENTITY("entity"),
|
||||
|
||||
NUMBER("m"),
|
||||
@@ -44,7 +46,7 @@ public enum DictWordType {
|
||||
//domain
|
||||
String[] natures = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (natures.length == 2 && StringUtils.isNumeric(natures[1])) {
|
||||
return DOMAIN;
|
||||
return MODEL;
|
||||
}
|
||||
//dimension value
|
||||
if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) {
|
||||
|
||||
Reference in New Issue
Block a user