mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(improvement)(Chat) Fixed the issue of ineffective filtering in mapper detectDataSetIds, resolved the autocomplete feature, and changed METRIC_TAG to METRIC_ID. (#819)
This commit is contained in:
@@ -56,21 +56,24 @@ public class KnowledgeService {
|
||||
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return prefixSearchByModel(key, limit, modelIdToDataSetIds);
|
||||
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds);
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return suffixSearchByModel(key, limit, modelIdToDataSetIds.keySet());
|
||||
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Set<Long> models) {
|
||||
return SearchService.suffixSearch(key, limit, models);
|
||||
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -7,12 +7,7 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
@@ -21,6 +16,11 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -31,9 +31,9 @@ public class MetaEmbeddingService {
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
// dataSetIds->modelIds
|
||||
Set<Long> allModels = modelIdToDataSetIds.keySet();
|
||||
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
|
||||
@@ -41,13 +41,15 @@ public class SearchService {
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToViewIds) {
|
||||
return prefixSearch(key, limit, trie, modelIdToViewIds);
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToViewIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToViewIds.keySet());
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie,
|
||||
modelIdToDataSetIds, detectDataSetIds);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
@@ -58,7 +60,7 @@ public class SearchService {
|
||||
.collect(Collectors.toList());
|
||||
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
|
||||
List<String> natures = hanlpMapResult.getNatures().stream()
|
||||
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToViewIds))
|
||||
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds))
|
||||
.flatMap(Collection::stream).collect(Collectors.toList());
|
||||
hanlpMapResult.setNatures(natures);
|
||||
}
|
||||
@@ -70,14 +72,18 @@ public class SearchService {
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Set<Long> detectModelIds) {
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String reverseDetectSegment = StringUtils.reverse(key);
|
||||
return suffixSearch(reverseDetectSegment, limit, suffixTrie, detectModelIds);
|
||||
return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Set<Long> detectModelIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, detectModelIds);
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToDataSetIds,
|
||||
detectDataSetIds);
|
||||
|
||||
return result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
@@ -93,7 +99,10 @@ public class SearchService {
|
||||
}
|
||||
|
||||
private static Set<Map.Entry<String, List<String>>> prefixSearchLimit(String key, int limit,
|
||||
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
|
||||
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
|
||||
Set<Long> detectModelIds = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
key = key.toLowerCase();
|
||||
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();
|
||||
|
||||
|
||||
@@ -6,17 +6,17 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* nature parse helper
|
||||
@@ -220,4 +220,18 @@ public class NatureHelper {
|
||||
return 0L;
|
||||
}
|
||||
|
||||
public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectModelIds = modelIdToDataSetIds.keySet();
|
||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
detectModelIds = modelIdToDataSetIds.entrySet().stream().filter(entry -> {
|
||||
List<Long> dataSetIds = entry.getValue().stream().filter(detectDataSetIds::contains)
|
||||
.collect(Collectors.toList());
|
||||
if (!CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}).map(entry -> entry.getKey()).collect(Collectors.toSet());
|
||||
}
|
||||
return detectModelIds;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@@ -77,10 +77,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||
// step1. build query params
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds);
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
|
||||
@@ -34,7 +34,9 @@ public class EntityMapper extends BaseMapper {
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.TAG_VALUE.equals(schemaElementMatch.getElement().getType()
|
||||
))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
|
||||
@@ -39,7 +39,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -65,11 +65,11 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
@@ -33,7 +33,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
@@ -58,9 +58,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
|
||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
|
||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
|
||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricIdQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
@@ -94,7 +94,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return matches.stream().anyMatch(m -> {
|
||||
SchemaElementType type = m.getElement().getType();
|
||||
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
||||
&& !(ruleQuery instanceof MetricTagQuery)) {
|
||||
&& !(ruleQuery instanceof MetricIdQuery)) {
|
||||
return types.contains(type);
|
||||
}
|
||||
return type.equals(matchType);
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
package com.tencent.supersonic.headless.core.chat.query.rule.metric;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricTagQuery extends MetricSemanticQuery {
|
||||
public class MetricIdQuery extends MetricSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "METRIC_TAG";
|
||||
public static final String QUERY_MODE = "METRIC_ID";
|
||||
|
||||
public MetricTagQuery() {
|
||||
public MetricIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
|
||||
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
Reference in New Issue
Block a user