mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) remove domain in wordservice and the model id takes precedence over the number of model aliases (#87)
This commit is contained in:
@@ -1,23 +1,23 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
import java.util.Comparator;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -25,6 +25,13 @@ public class HeuristicModelResolver implements ModelResolver {
|
|||||||
|
|
||||||
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
|
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
|
||||||
SchemaMapInfo schemaMap) {
|
SchemaMapInfo schemaMap) {
|
||||||
|
//model count priority
|
||||||
|
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
|
||||||
|
if (Objects.nonNull(modelIdByModelCount)) {
|
||||||
|
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||||
|
return modelIdByModelCount;
|
||||||
|
}
|
||||||
|
|
||||||
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||||
if (modelTypeMap.size() == 1) {
|
if (modelTypeMap.size() == 1) {
|
||||||
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
||||||
@@ -33,6 +40,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
|||||||
return modelSelect;
|
return modelSelect;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
||||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
||||||
.sorted((o1, o2) -> {
|
.sorted((o1, o2) -> {
|
||||||
@@ -51,6 +59,31 @@ public class HeuristicModelResolver implements ModelResolver {
|
|||||||
return 0L;
|
return 0L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
|
||||||
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||||
|
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
|
||||||
|
if (Objects.nonNull(modelElementMatches)) {
|
||||||
|
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||||
|
Long modelId = modelElementMatch.getKey();
|
||||||
|
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
|
||||||
|
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
|
||||||
|
).collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEmpty(modelMatches)) {
|
||||||
|
Integer count = modelMatches.size();
|
||||||
|
modelIdToModelCount.put(modelId, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
|
||||||
|
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
|
||||||
|
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
|
||||||
|
if (Objects.nonNull(maxModelCount)) {
|
||||||
|
return maxModelCount.getKey();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* to check whether the Model can be switched if a Model exists in the context
|
* to check whether the Model can be switched if a Model exists in the context
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ public class WordBuilderFactory {
|
|||||||
static {
|
static {
|
||||||
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
|
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
|
||||||
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
|
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
|
||||||
wordNatures.put(DictWordType.DOMAIN, new ModelWordBuilder());
|
wordNatures.put(DictWordType.MODEL, new ModelWordBuilder());
|
||||||
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
||||||
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ public class WordService {
|
|||||||
|
|
||||||
addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words);
|
addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words);
|
||||||
addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words);
|
addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words);
|
||||||
addWordsByType(DictWordType.DOMAIN, semanticSchema.getModels(), words);
|
addWordsByType(DictWordType.MODEL, semanticSchema.getModels(), words);
|
||||||
addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words);
|
addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words);
|
||||||
addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words);
|
addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words);
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.tencent.supersonic.knowledge.utils;
|
||||||
|
|
||||||
|
|
||||||
|
import cn.hutool.core.lang.Assert;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
class NatureHelperTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void convertToElementType() {
|
||||||
|
SchemaElementType schemaElementType = NatureHelper.convertToElementType("_1");
|
||||||
|
|
||||||
|
Assert.equals(schemaElementType, SchemaElementType.MODEL);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,11 +8,13 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
*/
|
*/
|
||||||
public enum DictWordType {
|
public enum DictWordType {
|
||||||
METRIC("metric"),
|
METRIC("metric"),
|
||||||
|
|
||||||
DIMENSION("dimension"),
|
DIMENSION("dimension"),
|
||||||
|
|
||||||
VALUE("value"),
|
VALUE("value"),
|
||||||
|
|
||||||
DOMAIN("dm"),
|
|
||||||
MODEL("model"),
|
MODEL("model"),
|
||||||
|
|
||||||
ENTITY("entity"),
|
ENTITY("entity"),
|
||||||
|
|
||||||
NUMBER("m"),
|
NUMBER("m"),
|
||||||
@@ -44,7 +46,7 @@ public enum DictWordType {
|
|||||||
//domain
|
//model
|
||||||
String[] natures = nature.split(DictWordType.NATURE_SPILT);
|
String[] natures = nature.split(DictWordType.NATURE_SPILT);
|
||||||
if (natures.length == 2 && StringUtils.isNumeric(natures[1])) {
|
if (natures.length == 2 && StringUtils.isNumeric(natures[1])) {
|
||||||
return DOMAIN;
|
return MODEL;
|
||||||
}
|
}
|
||||||
//dimension value
|
//dimension value
|
||||||
if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) {
|
if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) {
|
||||||
|
|||||||
Reference in New Issue
Block a user