(improvement)(Chat) Fixed the issue of ineffective filtering in mapper detectDataSetIds, resolved the autocomplete feature, and changed METRIC_TAG to METRIC_ID. (#819)

This commit is contained in:
lexluo09
2024-03-14 16:58:41 +08:00
committed by GitHub
parent 901770f02c
commit 30ee64efec
25 changed files with 148 additions and 168 deletions

View File

@@ -56,21 +56,24 @@ public class KnowledgeService {
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
}
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
return prefixSearchByModel(key, limit, modelIdToDataSetIds);
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds) {
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds);
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
return suffixSearchByModel(key, limit, modelIdToDataSetIds.keySet());
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Set<Long> models) {
return SearchService.suffixSearch(key, limit, models);
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
}

View File

@@ -7,12 +7,7 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
@@ -21,6 +16,11 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
@@ -31,9 +31,9 @@ public class MetaEmbeddingService {
private EmbeddingConfig embeddingConfig;
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
Map<Long, List<Long>> modelIdToDataSetIds) {
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
// dataSetIds->modelIds
Set<Long> allModels = modelIdToDataSetIds.keySet();
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
Map<String, String> filterCondition = new HashMap<>();

View File

@@ -41,13 +41,15 @@ public class SearchService {
* @param key
* @return
*/
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToViewIds) {
return prefixSearch(key, limit, trie, modelIdToViewIds);
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds);
}
public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie,
Map<Long, List<Long>> modelIdToViewIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToViewIds.keySet());
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie,
modelIdToDataSetIds, detectDataSetIds);
List<HanlpMapResult> hanlpMapResults = result.stream().map(
entry -> {
String name = entry.getKey().replace("#", " ");
@@ -58,7 +60,7 @@ public class SearchService {
.collect(Collectors.toList());
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
List<String> natures = hanlpMapResult.getNatures().stream()
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToViewIds))
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds))
.flatMap(Collection::stream).collect(Collectors.toList());
hanlpMapResult.setNatures(natures);
}
@@ -70,14 +72,18 @@ public class SearchService {
* @param key
* @return
*/
public static List<HanlpMapResult> suffixSearch(String key, int limit, Set<Long> detectModelIds) {
public static List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
String reverseDetectSegment = StringUtils.reverse(key);
return suffixSearch(reverseDetectSegment, limit, suffixTrie, detectModelIds);
return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
}
public static List<HanlpMapResult> suffixSearch(String key, int limit, BinTrie<List<String>> binTrie,
Set<Long> detectModelIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, detectModelIds);
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToDataSetIds,
detectDataSetIds);
return result.stream().map(
entry -> {
String name = entry.getKey().replace("#", " ");
@@ -93,7 +99,10 @@ public class SearchService {
}
private static Set<Map.Entry<String, List<String>>> prefixSearchLimit(String key, int limit,
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Long> detectModelIds = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
key = key.toLowerCase();
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();

View File

@@ -6,17 +6,17 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
/**
* nature parse helper
@@ -220,4 +220,18 @@ public class NatureHelper {
return 0L;
}
public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Long> detectModelIds = modelIdToDataSetIds.keySet();
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
detectModelIds = modelIdToDataSetIds.entrySet().stream().filter(entry -> {
List<Long> dataSetIds = entry.getValue().stream().filter(detectDataSetIds::contains)
.collect(Collectors.toList());
if (!CollectionUtils.isEmpty(dataSetIds)) {
return true;
}
return false;
}).map(entry -> entry.getKey()).collect(Collectors.toSet());
}
return detectModelIds;
}
}

View File

@@ -49,7 +49,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
String detectSegment, int offset) {
}
@@ -77,10 +77,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
retrieveQuery, embeddingNumber, modelIdToDataSetIds);
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;

View File

@@ -34,7 +34,9 @@ public class EntityMapper extends BaseMapper {
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.TAG_VALUE.equals(schemaElementMatch.getElement().getType()
))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {

View File

@@ -39,7 +39,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
@@ -65,11 +65,11 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
hanlpMapResults.addAll(suffixHanlpMapResults);

View File

@@ -33,7 +33,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
@@ -58,9 +58,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity name where search
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricIdQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.AbstractMap;
@@ -94,7 +94,7 @@ public class ContextInheritParser implements SemanticParser {
return matches.stream().anyMatch(m -> {
SchemaElementType type = m.getElement().getType();
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
&& !(ruleQuery instanceof MetricTagQuery)) {
&& !(ruleQuery instanceof MetricIdQuery)) {
return types.contains(type);
}
return type.equals(matchType);

View File

@@ -1,31 +1,31 @@
package com.tencent.supersonic.headless.core.chat.query.rule.metric;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
@Component
public class MetricTagQuery extends MetricSemanticQuery {
public class MetricIdQuery extends MetricSemanticQuery {
public static final String QUERY_MODE = "METRIC_TAG";
public static final String QUERY_MODE = "METRIC_ID";
public MetricTagQuery() {
public MetricIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);