(improvement)(Chat) Fixed the issue of ineffective filtering in mapper detectDataSetIds, resolved the autocomplete feature, and changed METRIC_TAG to METRIC_ID. (#819)

This commit is contained in:
lexluo09
2024-03-14 16:58:41 +08:00
committed by GitHub
parent 901770f02c
commit 30ee64efec
25 changed files with 148 additions and 168 deletions

View File

@@ -13,15 +13,15 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.util.CollectionUtils;
/**
* MetricRecommendProcessor fills recommended metrics based on embedding similarity.
@@ -49,7 +49,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
.filterCondition(filterCondition).queryEmbeddings(null).build();
MetaEmbeddingService metaEmbeddingService = ContextUtils.getBean(MetaEmbeddingService.class);
List<RetrieveQueryResult> retrieveQueryResults =
metaEmbeddingService.retrieveQuery(retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>());
metaEmbeddingService.retrieveQuery(retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(),
new HashSet<>());
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
}

View File

@@ -56,21 +56,24 @@ public class KnowledgeService {
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
}
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
return prefixSearchByModel(key, limit, modelIdToDataSetIds);
public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds) {
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds);
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds) {
return suffixSearchByModel(key, limit, modelIdToDataSetIds.keySet());
public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Set<Long> models) {
return SearchService.suffixSearch(key, limit, models);
public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
}

View File

@@ -7,12 +7,7 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
@@ -21,6 +16,11 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
@@ -31,9 +31,9 @@ public class MetaEmbeddingService {
private EmbeddingConfig embeddingConfig;
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
Map<Long, List<Long>> modelIdToDataSetIds) {
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
// dataSetIds->modelIds
Set<Long> allModels = modelIdToDataSetIds.keySet();
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
Map<String, String> filterCondition = new HashMap<>();

View File

@@ -41,13 +41,15 @@ public class SearchService {
* @param key
* @return
*/
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToViewIds) {
return prefixSearch(key, limit, trie, modelIdToViewIds);
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds);
}
public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie,
Map<Long, List<Long>> modelIdToViewIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToViewIds.keySet());
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie,
modelIdToDataSetIds, detectDataSetIds);
List<HanlpMapResult> hanlpMapResults = result.stream().map(
entry -> {
String name = entry.getKey().replace("#", " ");
@@ -58,7 +60,7 @@ public class SearchService {
.collect(Collectors.toList());
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
List<String> natures = hanlpMapResult.getNatures().stream()
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToViewIds))
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds))
.flatMap(Collection::stream).collect(Collectors.toList());
hanlpMapResult.setNatures(natures);
}
@@ -70,14 +72,18 @@ public class SearchService {
* @param key
* @return
*/
public static List<HanlpMapResult> suffixSearch(String key, int limit, Set<Long> detectModelIds) {
public static List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
String reverseDetectSegment = StringUtils.reverse(key);
return suffixSearch(reverseDetectSegment, limit, suffixTrie, detectModelIds);
return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
}
public static List<HanlpMapResult> suffixSearch(String key, int limit, BinTrie<List<String>> binTrie,
Set<Long> detectModelIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, detectModelIds);
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToDataSetIds,
detectDataSetIds);
return result.stream().map(
entry -> {
String name = entry.getKey().replace("#", " ");
@@ -93,7 +99,10 @@ public class SearchService {
}
private static Set<Map.Entry<String, List<String>>> prefixSearchLimit(String key, int limit,
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Long> detectModelIds = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
key = key.toLowerCase();
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();

View File

@@ -6,17 +6,17 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
/**
* nature parse helper
@@ -220,4 +220,18 @@ public class NatureHelper {
return 0L;
}
public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
Set<Long> detectModelIds = modelIdToDataSetIds.keySet();
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
detectModelIds = modelIdToDataSetIds.entrySet().stream().filter(entry -> {
List<Long> dataSetIds = entry.getValue().stream().filter(detectDataSetIds::contains)
.collect(Collectors.toList());
if (!CollectionUtils.isEmpty(dataSetIds)) {
return true;
}
return false;
}).map(entry -> entry.getKey()).collect(Collectors.toSet());
}
return detectModelIds;
}
}

View File

@@ -49,7 +49,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
String detectSegment, int offset) {
}
@@ -77,10 +77,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
retrieveQuery, embeddingNumber, modelIdToDataSetIds);
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;

View File

@@ -34,7 +34,9 @@ public class EntityMapper extends BaseMapper {
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.TAG_VALUE.equals(schemaElementMatch.getElement().getType()
))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {

View File

@@ -39,7 +39,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
@@ -65,11 +65,11 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds())
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
hanlpMapResults.addAll(suffixHanlpMapResults);

View File

@@ -33,7 +33,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
@@ -58,9 +58,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds());
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity name where search
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricIdQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.AbstractMap;
@@ -94,7 +94,7 @@ public class ContextInheritParser implements SemanticParser {
return matches.stream().anyMatch(m -> {
SchemaElementType type = m.getElement().getType();
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
&& !(ruleQuery instanceof MetricTagQuery)) {
&& !(ruleQuery instanceof MetricIdQuery)) {
return types.contains(type);
}
return type.equals(matchType);

View File

@@ -1,31 +1,31 @@
package com.tencent.supersonic.headless.core.chat.query.rule.metric;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
@Component
public class MetricTagQuery extends MetricSemanticQuery {
public class MetricIdQuery extends MetricSemanticQuery {
public static final String QUERY_MODE = "METRIC_TAG";
public static final String QUERY_MODE = "METRIC_ID";
public MetricTagQuery() {
public MetricIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);

View File

@@ -193,7 +193,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
SemanticParseInfo parseInfo, User user) throws Exception {
SemanticParseInfo parseInfo, User user) throws Exception {
SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user);
QueryResult queryResult = new QueryResult();
if (queryResp != null) {
@@ -591,7 +591,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
//search from prefixSearch
List<HanlpMapResult> hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(),
2000, modelIdToDataSetIds);
2000, modelIdToDataSetIds, dataSetIds);
HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream()
.filter(o -> {

View File

@@ -78,8 +78,11 @@ public class SearchServiceImpl implements SearchService {
QueryContext queryContext = new QueryContext();
BeanUtils.copyProperties(queryReq, queryContext);
queryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryContext, originals, dataSetIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
// 3.get the most matching data
@@ -100,16 +103,16 @@ public class SearchServiceImpl implements SearchService {
Set<SearchResult> searchResults = new LinkedHashSet();
DataSetInfoStat dataSetInfoStat = NatureHelper.getDataSetStat(originals);
List<Long> possibleModels = getPossibleDataSets(queryReq, originals, dataSetInfoStat, dataSetIds);
List<Long> possibleDataSets = getPossibleDataSets(queryReq, originals, dataSetInfoStat, dataSetIds);
// 5.1 priority dimension metric
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), dataSetIdToName,
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDataSets), dataSetIdToName,
searchTextEntry, searchResults);
// 5.2 process based on dimension values
MatchText matchText = searchTextEntry.getKey();
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels));
log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap);
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleDataSets));
log.debug("possibleDataSets:{},natureToNameMap:{}", possibleDataSets, natureToNameMap);
for (Map.Entry<String, String> natureToNameEntry : natureToNameMap.entrySet()) {
@@ -123,23 +126,23 @@ public class SearchServiceImpl implements SearchService {
}
private List<Long> getPossibleDataSets(QueryReq queryCtx, List<S2Term> originals,
DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) {
DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) {
if (CollectionUtils.isNotEmpty(dataSetIds)) {
return new ArrayList<>(dataSetIds);
}
List<Long> possibleModels = NatureHelper.selectPossibleDataSets(originals);
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
Long contextModel = chatContextService.getContextModel(queryCtx.getChatId());
log.debug("possibleModels:{},dataSetInfoStat:{},contextModel:{}",
possibleModels, dataSetInfoStat, contextModel);
log.debug("possibleDataSets:{},dataSetInfoStat:{},contextModel:{}",
possibleDataSets, dataSetInfoStat, contextModel);
// If nothing is recognized or only metric are present, then add the contextModel.
if (nothingOrOnlyMetric(dataSetInfoStat)) {
return Lists.newArrayList(contextModel);
}
return possibleModels;
return possibleDataSets;
}
private boolean nothingOrOnlyMetric(DataSetInfoStat modelStat) {
@@ -175,7 +178,6 @@ public class SearchServiceImpl implements SearchService {
.subRecommend(wordName)
.build();
if (metricModelCount <= 0 && !existMetricAndDimension) {
if (filterByQueryFilter(wordName, queryFilters)) {
return searchResults;
@@ -265,7 +267,7 @@ public class SearchServiceImpl implements SearchService {
LinkedHashMap::new));
}
private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
private boolean searchMetricAndDimension(Set<Long> possibleDataSets, Map<Long, String> modelToName,
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry, Set<SearchResult> searchResults) {
boolean existMetric = false;
log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
@@ -277,7 +279,7 @@ public class SearchServiceImpl implements SearchService {
List<ModelWithSemanticType> dimensionMetricClassIds = hanlpMapResult.getNatures().stream()
.map(nature -> new ModelWithSemanticType(NatureHelper.getDataSetId(nature),
NatureHelper.convertToElementType(nature)))
.filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());
.filter(entry -> matchCondition(entry, possibleDataSets)).collect(Collectors.toList());
if (CollectionUtils.isEmpty(dimensionMetricClassIds)) {
continue;
@@ -296,22 +298,22 @@ public class SearchServiceImpl implements SearchService {
//visibility to filter metrics
searchResults.add(searchResult);
}
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", hanlpMapResult,
dimensionMetricClassIds, possibleModels);
log.info("parseResult:{},dimensionMetricClassIds:{},possibleDataSets:{}", hanlpMapResult,
dimensionMetricClassIds, possibleDataSets);
}
log.info("searchMetricAndDimension searchResults:{}", searchResults);
return existMetric;
}
private boolean matchCondition(ModelWithSemanticType entry, Set<Long> possibleModels) {
private boolean matchCondition(ModelWithSemanticType entry, Set<Long> possibleDataSets) {
if (!(SchemaElementType.METRIC.equals(entry.getSchemaElementType()) || SchemaElementType.DIMENSION.equals(
entry.getSchemaElementType()))) {
return false;
}
if (CollectionUtils.isEmpty(possibleModels)) {
if (CollectionUtils.isEmpty(possibleDataSets)) {
return true;
}
return possibleModels.contains(entry.getModel());
return possibleDataSets.contains(entry.getModel());
}
}

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.headless.server.service.impl;
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
@@ -20,11 +22,6 @@ import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.service.TagQueryService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
@@ -34,8 +31,10 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service
@Slf4j
@@ -91,6 +90,7 @@ public class TagQueryServiceImpl implements TagQueryService {
if (CollectionUtils.isEmpty(timeDimension)) {
return;
}
// query date info from db
String endDate = queryTagDateFromDbBySql(timeDimension.get(0), tag, user);
DateConf dateConf = new DateConf();

View File

@@ -105,6 +105,7 @@ public class TagConverter {
if (!CollectionUtils.isEmpty(queryTagReq.getTagFilters())) {
queryStructReq.setDimensionFilters(queryTagReq.getTagFilters());
}
queryStructReq.setQueryType(QueryType.TAG);
QuerySqlReq querySqlReq = queryStructReq.convert();
convert(querySqlReq, semanticSchemaResp, queryStatement, queryStructReq);
QueryParam queryParam = new QueryParam();

View File

@@ -445,6 +445,19 @@ public class ModelDemoDataLoader {
tagDefineParam3s.setDependencies(new ArrayList<>(Arrays.asList(5)));
tagReq3.setTagDefineParams(tagDefineParam3s);
tagMetaService.create(tagReq3, user);
TagReq tagReq4 = new TagReq();
tagReq4.setModelId(4L);
tagReq4.setName("歌手名");
tagReq4.setBizName("singer_name");
tagReq4.setStatus(StatusEnum.ONLINE.getCode());
tagReq4.setTypeEnum(TypeEnums.TAG);
tagReq4.setTagDefineType(TagDefineType.DIMENSION);
TagDefineParams tagDefineParam4s = new TagDefineParams();
tagDefineParam4s.setExpr("singer_name");
tagDefineParam4s.setDependencies(new ArrayList<>(Arrays.asList(7)));
tagReq4.setTagDefineParams(tagDefineParam4s);
tagMetaService.create(tagReq4, user);
}
public void addMetric_uv() throws Exception {
@@ -529,7 +542,7 @@ public class ModelDemoDataLoader {
dataSetReq.setAdmins(Lists.newArrayList("admin", "jack"));
List<DataSetModelConfig> dataSetModelConfigs = Lists.newArrayList(
new DataSetModelConfig(4L, Lists.newArrayList(4L, 5L, 6L, 7L),
Lists.newArrayList(5L, 6L, 7L), Lists.newArrayList(1L, 2L, 3L))
Lists.newArrayList(5L, 6L, 7L), Lists.newArrayList(1L, 2L, 3L, 4L))
);
DataSetDetail dataSetDetail = new DataSetDetail();
dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs);

View File

@@ -0,0 +1,6 @@
周杰伦 _4_4_tv 100
陈奕迅 _4_4_tv 100
林俊杰 _4_4_tv 100
张碧晨 _4_4_tv 100
程响 _4_4_tv 100
Taylor#Swift _4_4_tv 100

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
@@ -9,40 +8,12 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.tag.TagFilterQuery;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
public class TagTest extends BaseTest {
@Test
public void queryTest_metric_tag_query() throws Exception {
MockConfiguration.mockTagAgent(agentService);
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量", DataUtils.tagAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
SchemaElement metric = SchemaElement.builder().name("播放量").build();
expectedParseInfo.getMetrics().add(metric);
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateMode.RECENT, 7, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
}
@Test
public void queryTest_tag_list_filter() throws Exception {
MockConfiguration.mockTagAgent(agentService);
@@ -55,8 +26,6 @@ public class TagTest extends BaseTest {
expectedResult.setQueryMode(TagFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
List<String> list = new ArrayList<>();
list.add("流行");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
"流行", "风格", 2L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);

View File

@@ -1,47 +0,0 @@
package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.core.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Test;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
public class MapperTest extends BaseTest {
@Test
public void hanlp() throws Exception {
ChatParseReq chatParseReq = DataUtils.getChatParseReq(10, "艺人周杰伦的播放量");
chatParseReq.setAgentId(DataUtils.tagAgentId);
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量", DataUtils.tagAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricTagQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
SchemaElement metric = SchemaElement.builder().name("播放量").build();
expectedParseInfo.getMetrics().add(metric);
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, 7, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
}
}

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.util;
import static java.time.LocalDate.now;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
@@ -13,12 +15,9 @@ import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import java.util.HashSet;
import java.util.Set;
import static java.time.LocalDate.now;
public class DataUtils {
public static final Integer metricAgentId = 1;
@@ -140,7 +139,7 @@ public class DataUtils {
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_TAG", "METRIC_FILTER", "METRIC_MODEL",
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ID", "METRIC_FILTER", "METRIC_MODEL",
"TAG_DETAIL", "TAG_LIST_FILTER", "TAG_ID"));
return ruleQueryTool;
}

View File

@@ -0,0 +1,6 @@
周杰伦 _4_4_tv 100
陈奕迅 _4_4_tv 100
林俊杰 _4_4_tv 100
张碧晨 _4_4_tv 100
程响 _4_4_tv 100
Taylor#Swift _4_4_tv 100

View File

@@ -93,7 +93,7 @@ export type ChatContextType = {
elementMatches: any[];
nativeQuery: boolean;
queryMode: string;
queryType: 'METRIC' | 'METRIC_TAG' | 'ID' | 'TAG' | 'OTHER';
queryType: 'METRIC' | 'METRIC_ID' | 'ID' | 'TAG' | 'OTHER';
dimensionFilters: FilterItemType[];
properties: any;
sqlInfo: SqlInfoType;

View File

@@ -152,11 +152,11 @@ const ParseTip: React.FC<Props> = ({
<div className={itemValueClass}>{dataSet?.name}</div>
</div>
)}
{(queryType === 'METRIC' || queryType === 'METRIC_TAG' || queryType === 'TAG') && (
{(queryType === 'METRIC' || queryType === 'METRIC_ID' || queryType === 'TAG') && (
<div className={`${prefixCls}-tip-item`}>
<div className={`${prefixCls}-tip-item-name`}></div>
<div className={itemValueClass}>
{queryType === 'METRIC' || queryType === 'METRIC_TAG' ? '指标模式' : '标签模式'}
{queryType === 'METRIC' || queryType === 'METRIC_ID' ? '指标模式' : '标签模式'}
</div>
</div>
)}

View File

@@ -45,7 +45,7 @@ const Message: React.FC<Props> = ({
e.stopPropagation();
}}
>
{(queryMode === 'METRIC_TAG' || queryMode === 'TAG_DETAIL') &&
{(queryMode === 'METRIC_ID' || queryMode === 'TAG_DETAIL') &&
entityInfoList.length > 0 && (
<div className={`${prefixCls}-info-bar`}>
<div className={`${prefixCls}-main-entity-info`}>

View File

@@ -233,7 +233,7 @@ const ChatMsg: React.FC<Props> = ({ queryId, data, chartIndex, triggerResize })
?.name;
const isEntityMode =
(queryMode === 'TAG_LIST_FILTER' || queryMode === 'METRIC_TAG') &&
(queryMode === 'TAG_LIST_FILTER' || queryMode === 'METRIC_ID') &&
typeof entityId === 'string' &&
entityName !== undefined;