(improvement)(chat) remove domain in wordservice and the model id takes precedence over the number of model aliases (#87)

This commit is contained in:
lexluo09
2023-09-13 21:25:31 +08:00
committed by GitHub
parent c802c508fb
commit 6a98ce9d28
5 changed files with 65 additions and 14 deletions

View File

@@ -1,23 +1,23 @@
package com.tencent.supersonic.chat.parser.plugin.function; package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import java.util.Comparator;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
import java.util.HashMap; import java.util.HashMap;
import java.util.Objects;
import java.util.List;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@Slf4j @Slf4j
@@ -25,6 +25,13 @@ public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes, protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
SchemaMapInfo schemaMap) { SchemaMapInfo schemaMap) {
//model count priority
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
if (Objects.nonNull(modelIdByModelCount)) {
log.info("selectModel by model count:{}", modelIdByModelCount);
return modelIdByModelCount;
}
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap); Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
if (modelTypeMap.size() == 1) { if (modelTypeMap.size() == 1) {
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey(); Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
@@ -33,6 +40,7 @@ public class HeuristicModelResolver implements ModelResolver {
return modelSelect; return modelSelect;
} }
} else { } else {
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream() Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
.filter(entry -> modelQueryModes.containsKey(entry.getKey())) .filter(entry -> modelQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> { .sorted((o1, o2) -> {
@@ -51,6 +59,31 @@ public class HeuristicModelResolver implements ModelResolver {
return 0L; return 0L;
} }
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
if (Objects.nonNull(modelElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
Long modelId = modelElementMatch.getKey();
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(modelMatches)) {
Integer count = modelMatches.size();
modelIdToModelCount.put(modelId, count);
}
}
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
if (Objects.nonNull(maxModelCount)) {
return maxModelCount.getKey();
}
}
return null;
}
/** /**
* to check can switch Model if context exit Model * to check can switch Model if context exit Model
* *

View File

@@ -15,7 +15,7 @@ public class WordBuilderFactory {
static { static {
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder()); wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder()); wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
wordNatures.put(DictWordType.DOMAIN, new ModelWordBuilder()); wordNatures.put(DictWordType.MODEL, new ModelWordBuilder());
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder()); wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder()); wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
} }

View File

@@ -29,7 +29,7 @@ public class WordService {
addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words); addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words);
addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words); addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words);
addWordsByType(DictWordType.DOMAIN, semanticSchema.getModels(), words); addWordsByType(DictWordType.MODEL, semanticSchema.getModels(), words);
addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words); addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words);
addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words); addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words);

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.knowledge.utils;
import cn.hutool.core.lang.Assert;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import org.junit.jupiter.api.Test;
class NatureHelperTest {
@Test
void convertToElementType() {
SchemaElementType schemaElementType = NatureHelper.convertToElementType("_1");
Assert.equals(schemaElementType, SchemaElementType.MODEL);
}
}

View File

@@ -8,11 +8,13 @@ import org.apache.commons.lang3.StringUtils;
*/ */
public enum DictWordType { public enum DictWordType {
METRIC("metric"), METRIC("metric"),
DIMENSION("dimension"), DIMENSION("dimension"),
VALUE("value"), VALUE("value"),
DOMAIN("dm"),
MODEL("model"), MODEL("model"),
ENTITY("entity"), ENTITY("entity"),
NUMBER("m"), NUMBER("m"),
@@ -44,7 +46,7 @@ public enum DictWordType {
//domain //domain
String[] natures = nature.split(DictWordType.NATURE_SPILT); String[] natures = nature.split(DictWordType.NATURE_SPILT);
if (natures.length == 2 && StringUtils.isNumeric(natures[1])) { if (natures.length == 2 && StringUtils.isNumeric(natures[1])) {
return DOMAIN; return MODEL;
} }
//dimension value //dimension value
if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) { if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) {