mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
(improvement)(Chat) Fixed the issue of ineffective filtering in mapper detectDataSetIds, resolved the autocomplete feature, and changed METRIC_TAG to METRIC_ID. (#819)
This commit is contained in:
@@ -13,15 +13,15 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* MetricRecommendProcessor fills recommended metrics based on embedding similarity.
|
||||
@@ -49,7 +49,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||
MetaEmbeddingService metaEmbeddingService = ContextUtils.getBean(MetaEmbeddingService.class);
|
||||
List<RetrieveQueryResult> retrieveQueryResults =
|
||||
metaEmbeddingService.retrieveQuery(retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>());
|
||||
metaEmbeddingService.retrieveQuery(retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(),
|
||||
new HashSet<>());
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -56,21 +56,24 @@ public class KnowledgeService {
|
||||
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return prefixSearchByModel(key, limit, modelIdToDataSetIds);
|
||||
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds);
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return suffixSearchByModel(key, limit, modelIdToDataSetIds.keySet());
|
||||
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Set<Long> models) {
|
||||
return SearchService.suffixSearch(key, limit, models);
|
||||
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -7,12 +7,7 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
@@ -21,6 +16,11 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -31,9 +31,9 @@ public class MetaEmbeddingService {
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
// dataSetIds->modelIds
|
||||
Set<Long> allModels = modelIdToDataSetIds.keySet();
|
||||
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
|
||||
@@ -41,13 +41,15 @@ public class SearchService {
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToViewIds) {
|
||||
return prefixSearch(key, limit, trie, modelIdToViewIds);
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToViewIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToViewIds.keySet());
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie,
|
||||
modelIdToDataSetIds, detectDataSetIds);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
@@ -58,7 +60,7 @@ public class SearchService {
|
||||
.collect(Collectors.toList());
|
||||
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
|
||||
List<String> natures = hanlpMapResult.getNatures().stream()
|
||||
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToViewIds))
|
||||
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds))
|
||||
.flatMap(Collection::stream).collect(Collectors.toList());
|
||||
hanlpMapResult.setNatures(natures);
|
||||
}
|
||||
@@ -70,14 +72,18 @@ public class SearchService {
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Set<Long> detectModelIds) {
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String reverseDetectSegment = StringUtils.reverse(key);
|
||||
return suffixSearch(reverseDetectSegment, limit, suffixTrie, detectModelIds);
|
||||
return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Set<Long> detectModelIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, detectModelIds);
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToDataSetIds,
|
||||
detectDataSetIds);
|
||||
|
||||
return result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
@@ -93,7 +99,10 @@ public class SearchService {
|
||||
}
|
||||
|
||||
private static Set<Map.Entry<String, List<String>>> prefixSearchLimit(String key, int limit,
|
||||
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
|
||||
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
|
||||
Set<Long> detectModelIds = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
key = key.toLowerCase();
|
||||
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();
|
||||
|
||||
|
||||
@@ -6,17 +6,17 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* nature parse helper
|
||||
@@ -220,4 +220,18 @@ public class NatureHelper {
|
||||
return 0L;
|
||||
}
|
||||
|
||||
public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectModelIds = modelIdToDataSetIds.keySet();
|
||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
detectModelIds = modelIdToDataSetIds.entrySet().stream().filter(entry -> {
|
||||
List<Long> dataSetIds = entry.getValue().stream().filter(detectDataSetIds::contains)
|
||||
.collect(Collectors.toList());
|
||||
if (!CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}).map(entry -> entry.getKey()).collect(Collectors.toSet());
|
||||
}
|
||||
return detectModelIds;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@@ -77,10 +77,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||
// step1. build query params
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds);
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
|
||||
@@ -34,7 +34,9 @@ public class EntityMapper extends BaseMapper {
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.TAG_VALUE.equals(schemaElementMatch.getElement().getType()
|
||||
))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
|
||||
@@ -39,7 +39,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -65,11 +65,11 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
@@ -33,7 +33,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
@@ -58,9 +58,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
|
||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
|
||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
|
||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricIdQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
@@ -94,7 +94,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return matches.stream().anyMatch(m -> {
|
||||
SchemaElementType type = m.getElement().getType();
|
||||
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
||||
&& !(ruleQuery instanceof MetricTagQuery)) {
|
||||
&& !(ruleQuery instanceof MetricIdQuery)) {
|
||||
return types.contains(type);
|
||||
}
|
||||
return type.equals(matchType);
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
package com.tencent.supersonic.headless.core.chat.query.rule.metric;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricTagQuery extends MetricSemanticQuery {
|
||||
public class MetricIdQuery extends MetricSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "METRIC_TAG";
|
||||
public static final String QUERY_MODE = "METRIC_ID";
|
||||
|
||||
public MetricTagQuery() {
|
||||
public MetricIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
|
||||
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
@@ -193,7 +193,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
|
||||
SemanticParseInfo parseInfo, User user) throws Exception {
|
||||
SemanticParseInfo parseInfo, User user) throws Exception {
|
||||
SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
if (queryResp != null) {
|
||||
@@ -591,7 +591,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
|
||||
//search from prefixSearch
|
||||
List<HanlpMapResult> hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(),
|
||||
2000, modelIdToDataSetIds);
|
||||
2000, modelIdToDataSetIds, dataSetIds);
|
||||
HanlpHelper.transLetterOriginal(hanlpMapResultList);
|
||||
return hanlpMapResultList.stream()
|
||||
.filter(o -> {
|
||||
|
||||
@@ -78,8 +78,11 @@ public class SearchServiceImpl implements SearchService {
|
||||
|
||||
QueryContext queryContext = new QueryContext();
|
||||
BeanUtils.copyProperties(queryReq, queryContext);
|
||||
queryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
|
||||
|
||||
Map<MatchText, List<HanlpMapResult>> regTextMap =
|
||||
searchMatchStrategy.match(queryContext, originals, dataSetIds);
|
||||
|
||||
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
||||
|
||||
// 3.get the most matching data
|
||||
@@ -100,16 +103,16 @@ public class SearchServiceImpl implements SearchService {
|
||||
Set<SearchResult> searchResults = new LinkedHashSet();
|
||||
DataSetInfoStat dataSetInfoStat = NatureHelper.getDataSetStat(originals);
|
||||
|
||||
List<Long> possibleModels = getPossibleDataSets(queryReq, originals, dataSetInfoStat, dataSetIds);
|
||||
List<Long> possibleDataSets = getPossibleDataSets(queryReq, originals, dataSetInfoStat, dataSetIds);
|
||||
|
||||
// 5.1 priority dimension metric
|
||||
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), dataSetIdToName,
|
||||
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDataSets), dataSetIdToName,
|
||||
searchTextEntry, searchResults);
|
||||
|
||||
// 5.2 process based on dimension values
|
||||
MatchText matchText = searchTextEntry.getKey();
|
||||
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels));
|
||||
log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap);
|
||||
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleDataSets));
|
||||
log.debug("possibleDataSets:{},natureToNameMap:{}", possibleDataSets, natureToNameMap);
|
||||
|
||||
for (Map.Entry<String, String> natureToNameEntry : natureToNameMap.entrySet()) {
|
||||
|
||||
@@ -123,23 +126,23 @@ public class SearchServiceImpl implements SearchService {
|
||||
}
|
||||
|
||||
private List<Long> getPossibleDataSets(QueryReq queryCtx, List<S2Term> originals,
|
||||
DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) {
|
||||
DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) {
|
||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
||||
return new ArrayList<>(dataSetIds);
|
||||
}
|
||||
|
||||
List<Long> possibleModels = NatureHelper.selectPossibleDataSets(originals);
|
||||
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
|
||||
|
||||
Long contextModel = chatContextService.getContextModel(queryCtx.getChatId());
|
||||
|
||||
log.debug("possibleModels:{},dataSetInfoStat:{},contextModel:{}",
|
||||
possibleModels, dataSetInfoStat, contextModel);
|
||||
log.debug("possibleDataSets:{},dataSetInfoStat:{},contextModel:{}",
|
||||
possibleDataSets, dataSetInfoStat, contextModel);
|
||||
|
||||
// If nothing is recognized or only metric are present, then add the contextModel.
|
||||
if (nothingOrOnlyMetric(dataSetInfoStat)) {
|
||||
return Lists.newArrayList(contextModel);
|
||||
}
|
||||
return possibleModels;
|
||||
return possibleDataSets;
|
||||
}
|
||||
|
||||
private boolean nothingOrOnlyMetric(DataSetInfoStat modelStat) {
|
||||
@@ -175,7 +178,6 @@ public class SearchServiceImpl implements SearchService {
|
||||
.subRecommend(wordName)
|
||||
.build();
|
||||
|
||||
|
||||
if (metricModelCount <= 0 && !existMetricAndDimension) {
|
||||
if (filterByQueryFilter(wordName, queryFilters)) {
|
||||
return searchResults;
|
||||
@@ -265,7 +267,7 @@ public class SearchServiceImpl implements SearchService {
|
||||
LinkedHashMap::new));
|
||||
}
|
||||
|
||||
private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
|
||||
private boolean searchMetricAndDimension(Set<Long> possibleDataSets, Map<Long, String> modelToName,
|
||||
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry, Set<SearchResult> searchResults) {
|
||||
boolean existMetric = false;
|
||||
log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
|
||||
@@ -277,7 +279,7 @@ public class SearchServiceImpl implements SearchService {
|
||||
List<ModelWithSemanticType> dimensionMetricClassIds = hanlpMapResult.getNatures().stream()
|
||||
.map(nature -> new ModelWithSemanticType(NatureHelper.getDataSetId(nature),
|
||||
NatureHelper.convertToElementType(nature)))
|
||||
.filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());
|
||||
.filter(entry -> matchCondition(entry, possibleDataSets)).collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isEmpty(dimensionMetricClassIds)) {
|
||||
continue;
|
||||
@@ -296,22 +298,22 @@ public class SearchServiceImpl implements SearchService {
|
||||
//visibility to filter metrics
|
||||
searchResults.add(searchResult);
|
||||
}
|
||||
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", hanlpMapResult,
|
||||
dimensionMetricClassIds, possibleModels);
|
||||
log.info("parseResult:{},dimensionMetricClassIds:{},possibleDataSets:{}", hanlpMapResult,
|
||||
dimensionMetricClassIds, possibleDataSets);
|
||||
}
|
||||
log.info("searchMetricAndDimension searchResults:{}", searchResults);
|
||||
return existMetric;
|
||||
}
|
||||
|
||||
private boolean matchCondition(ModelWithSemanticType entry, Set<Long> possibleModels) {
|
||||
private boolean matchCondition(ModelWithSemanticType entry, Set<Long> possibleDataSets) {
|
||||
if (!(SchemaElementType.METRIC.equals(entry.getSchemaElementType()) || SchemaElementType.DIMENSION.equals(
|
||||
entry.getSchemaElementType()))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (CollectionUtils.isEmpty(possibleModels)) {
|
||||
if (CollectionUtils.isEmpty(possibleDataSets)) {
|
||||
return true;
|
||||
}
|
||||
return possibleModels.contains(entry.getModel());
|
||||
return possibleDataSets.contains(entry.getModel());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.headless.server.service.impl;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
@@ -20,11 +22,6 @@ import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
import com.tencent.supersonic.headless.server.service.TagMetaService;
|
||||
import com.tencent.supersonic.headless.server.service.TagQueryService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
@@ -34,8 +31,10 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -91,6 +90,7 @@ public class TagQueryServiceImpl implements TagQueryService {
|
||||
if (CollectionUtils.isEmpty(timeDimension)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// query date info from db
|
||||
String endDate = queryTagDateFromDbBySql(timeDimension.get(0), tag, user);
|
||||
DateConf dateConf = new DateConf();
|
||||
|
||||
@@ -105,6 +105,7 @@ public class TagConverter {
|
||||
if (!CollectionUtils.isEmpty(queryTagReq.getTagFilters())) {
|
||||
queryStructReq.setDimensionFilters(queryTagReq.getTagFilters());
|
||||
}
|
||||
queryStructReq.setQueryType(QueryType.TAG);
|
||||
QuerySqlReq querySqlReq = queryStructReq.convert();
|
||||
convert(querySqlReq, semanticSchemaResp, queryStatement, queryStructReq);
|
||||
QueryParam queryParam = new QueryParam();
|
||||
|
||||
@@ -445,6 +445,19 @@ public class ModelDemoDataLoader {
|
||||
tagDefineParam3s.setDependencies(new ArrayList<>(Arrays.asList(5)));
|
||||
tagReq3.setTagDefineParams(tagDefineParam3s);
|
||||
tagMetaService.create(tagReq3, user);
|
||||
|
||||
TagReq tagReq4 = new TagReq();
|
||||
tagReq4.setModelId(4L);
|
||||
tagReq4.setName("歌手名");
|
||||
tagReq4.setBizName("singer_name");
|
||||
tagReq4.setStatus(StatusEnum.ONLINE.getCode());
|
||||
tagReq4.setTypeEnum(TypeEnums.TAG);
|
||||
tagReq4.setTagDefineType(TagDefineType.DIMENSION);
|
||||
TagDefineParams tagDefineParam4s = new TagDefineParams();
|
||||
tagDefineParam4s.setExpr("singer_name");
|
||||
tagDefineParam4s.setDependencies(new ArrayList<>(Arrays.asList(7)));
|
||||
tagReq4.setTagDefineParams(tagDefineParam4s);
|
||||
tagMetaService.create(tagReq4, user);
|
||||
}
|
||||
|
||||
public void addMetric_uv() throws Exception {
|
||||
@@ -529,7 +542,7 @@ public class ModelDemoDataLoader {
|
||||
dataSetReq.setAdmins(Lists.newArrayList("admin", "jack"));
|
||||
List<DataSetModelConfig> dataSetModelConfigs = Lists.newArrayList(
|
||||
new DataSetModelConfig(4L, Lists.newArrayList(4L, 5L, 6L, 7L),
|
||||
Lists.newArrayList(5L, 6L, 7L), Lists.newArrayList(1L, 2L, 3L))
|
||||
Lists.newArrayList(5L, 6L, 7L), Lists.newArrayList(1L, 2L, 3L, 4L))
|
||||
);
|
||||
DataSetDetail dataSetDetail = new DataSetDetail();
|
||||
dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs);
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
周杰伦 _4_4_tv 100
|
||||
陈奕迅 _4_4_tv 100
|
||||
林俊杰 _4_4_tv 100
|
||||
张碧晨 _4_4_tv 100
|
||||
程响 _4_4_tv 100
|
||||
Taylor#Swift _4_4_tv 100
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
@@ -9,40 +8,12 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.tag.TagFilterQuery;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class TagTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_tag_query() throws Exception {
|
||||
MockConfiguration.mockTagAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量", DataUtils.tagAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
SchemaElement metric = SchemaElement.builder().name("播放量").build();
|
||||
expectedParseInfo.getMetrics().add(metric);
|
||||
|
||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateMode.RECENT, 7, period, startDay, endDay));
|
||||
expectedParseInfo.setQueryType(QueryType.METRIC);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_tag_list_filter() throws Exception {
|
||||
MockConfiguration.mockTagAgent(agentService);
|
||||
@@ -55,8 +26,6 @@ public class TagTest extends BaseTest {
|
||||
expectedResult.setQueryMode(TagFilterQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
List<String> list = new ArrayList<>();
|
||||
list.add("流行");
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
|
||||
"流行", "风格", 2L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.BaseTest;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
|
||||
public class MapperTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void hanlp() throws Exception {
|
||||
|
||||
ChatParseReq chatParseReq = DataUtils.getChatParseReq(10, "艺人周杰伦的播放量");
|
||||
chatParseReq.setAgentId(DataUtils.tagAgentId);
|
||||
|
||||
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量", DataUtils.tagAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
SchemaElement metric = SchemaElement.builder().name("播放量").build();
|
||||
expectedParseInfo.getMetrics().add(metric);
|
||||
|
||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, 7, period, startDay, endDay));
|
||||
expectedParseInfo.setQueryType(QueryType.METRIC);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.util;
|
||||
|
||||
import static java.time.LocalDate.now;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
@@ -13,12 +15,9 @@ import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import static java.time.LocalDate.now;
|
||||
|
||||
public class DataUtils {
|
||||
|
||||
public static final Integer metricAgentId = 1;
|
||||
@@ -140,7 +139,7 @@ public class DataUtils {
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
|
||||
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_TAG", "METRIC_FILTER", "METRIC_MODEL",
|
||||
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ID", "METRIC_FILTER", "METRIC_MODEL",
|
||||
"TAG_DETAIL", "TAG_LIST_FILTER", "TAG_ID"));
|
||||
return ruleQueryTool;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
周杰伦 _4_4_tv 100
|
||||
陈奕迅 _4_4_tv 100
|
||||
林俊杰 _4_4_tv 100
|
||||
张碧晨 _4_4_tv 100
|
||||
程响 _4_4_tv 100
|
||||
Taylor#Swift _4_4_tv 100
|
||||
@@ -93,7 +93,7 @@ export type ChatContextType = {
|
||||
elementMatches: any[];
|
||||
nativeQuery: boolean;
|
||||
queryMode: string;
|
||||
queryType: 'METRIC' | 'METRIC_TAG' | 'ID' | 'TAG' | 'OTHER';
|
||||
queryType: 'METRIC' | 'METRIC_ID' | 'ID' | 'TAG' | 'OTHER';
|
||||
dimensionFilters: FilterItemType[];
|
||||
properties: any;
|
||||
sqlInfo: SqlInfoType;
|
||||
|
||||
@@ -152,11 +152,11 @@ const ParseTip: React.FC<Props> = ({
|
||||
<div className={itemValueClass}>{dataSet?.name}</div>
|
||||
</div>
|
||||
)}
|
||||
{(queryType === 'METRIC' || queryType === 'METRIC_TAG' || queryType === 'TAG') && (
|
||||
{(queryType === 'METRIC' || queryType === 'METRIC_ID' || queryType === 'TAG') && (
|
||||
<div className={`${prefixCls}-tip-item`}>
|
||||
<div className={`${prefixCls}-tip-item-name`}>查询模式:</div>
|
||||
<div className={itemValueClass}>
|
||||
{queryType === 'METRIC' || queryType === 'METRIC_TAG' ? '指标模式' : '标签模式'}
|
||||
{queryType === 'METRIC' || queryType === 'METRIC_ID' ? '指标模式' : '标签模式'}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -45,7 +45,7 @@ const Message: React.FC<Props> = ({
|
||||
e.stopPropagation();
|
||||
}}
|
||||
>
|
||||
{(queryMode === 'METRIC_TAG' || queryMode === 'TAG_DETAIL') &&
|
||||
{(queryMode === 'METRIC_ID' || queryMode === 'TAG_DETAIL') &&
|
||||
entityInfoList.length > 0 && (
|
||||
<div className={`${prefixCls}-info-bar`}>
|
||||
<div className={`${prefixCls}-main-entity-info`}>
|
||||
|
||||
@@ -233,7 +233,7 @@ const ChatMsg: React.FC<Props> = ({ queryId, data, chartIndex, triggerResize })
|
||||
?.name;
|
||||
|
||||
const isEntityMode =
|
||||
(queryMode === 'TAG_LIST_FILTER' || queryMode === 'METRIC_TAG') &&
|
||||
(queryMode === 'TAG_LIST_FILTER' || queryMode === 'METRIC_ID') &&
|
||||
typeof entityId === 'string' &&
|
||||
entityName !== undefined;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user