[improvement](chat) support filtering query/search by web domainId, add frequency/detectWord to mapping results, and de-duplicate metrics/dimensions/orders/filters

lexluo
2023-06-15 18:15:44 +08:00
parent 1fd08be2cd
commit b6f0df40a9
57 changed files with 1040 additions and 332 deletions

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.response.QueryResultResp;
import com.tencent.supersonic.chat.domain.dataobject.ChatDO;
import com.tencent.supersonic.chat.domain.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.domain.dataobject.QueryDO;
import com.tencent.supersonic.chat.domain.pojo.chat.ChatQueryVO;
import com.tencent.supersonic.chat.domain.pojo.chat.PageQueryInfoReq;
@@ -130,4 +131,14 @@ public class ChatServiceImpl implements ChatService {
chatQueryRepository.createChatQuery(queryResponse, queryContext, chatCtx);
}
@Override
public ChatQueryDO getLastQuery(long chatId) {
return chatQueryRepository.getLastChatQuery(chatId);
}
@Override
public int updateQuery(ChatQueryDO chatQueryDO) {
return chatQueryRepository.updateChatQuery(chatQueryDO);
}
}

View File

@@ -20,6 +20,8 @@ import com.tencent.supersonic.chat.domain.utils.DefaultSemanticInternalUtils;
import com.tencent.supersonic.chat.domain.utils.SchemaInfoConverter;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.SchemaItem;
import java.util.LinkedHashSet;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
@@ -137,7 +139,7 @@ public class DomainEntityService {
chatFilter.setValue(String.valueOf(entity));
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
chatFilter.setBizName(getEntityPrimaryName(domainInfo));
List<Filter> chatFilters = new ArrayList<>();
Set<Filter> chatFilters = new LinkedHashSet();
chatFilters.add(chatFilter);
semanticParseInfo.setDimensionFilters(chatFilters);
@@ -167,8 +169,8 @@ public class DomainEntityService {
}
}
private List<SchemaItem> getDimensions(EntityInfo domainInfo) {
List<SchemaItem> dimensions = new ArrayList<>();
private Set<SchemaItem> getDimensions(EntityInfo domainInfo) {
Set<SchemaItem> dimensions = new LinkedHashSet();
for (DataInfo mainEntityDimension : domainInfo.getDimensions()) {
SchemaItem dimension = new SchemaItem();
dimension.setBizName(mainEntityDimension.getBizName());
@@ -186,8 +188,8 @@ public class DomainEntityService {
return entryKey;
}
private List<SchemaItem> getMetrics(EntityInfo domainInfo) {
List<SchemaItem> metrics = new ArrayList<>();
private Set<SchemaItem> getMetrics(EntityInfo domainInfo) {
Set<SchemaItem> metrics = new LinkedHashSet();
for (DataInfo metricValue : domainInfo.getMetrics()) {
SchemaItem metric = new SchemaItem();
metric.setBizName(metricValue.getBizName());
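
Note: the List -> LinkedHashSet changes above de-duplicate dimensions/metrics while keeping insertion order; this only collapses duplicates if SchemaItem has value-based equals/hashCode (an assumption here, e.g. via Lombok @Data). A minimal standalone sketch of the effect, using String in place of SchemaItem:

import java.util.LinkedHashSet;
import java.util.Set;

public class DedupSketch {
    public static void main(String[] args) {
        Set<String> dimensions = new LinkedHashSet<>();
        dimensions.add("department");
        dimensions.add("city");
        dimensions.add("department"); // duplicate is silently dropped
        System.out.println(dimensions); // [department, city] - first-insertion order preserved
    }
}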

View File

@@ -3,6 +3,8 @@ package com.tencent.supersonic.chat.application;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.response.QueryResultResp;
@@ -18,6 +20,9 @@ import com.tencent.supersonic.chat.domain.service.QueryService;
import com.tencent.supersonic.chat.domain.utils.SchemaInfoConverter;
import com.tencent.supersonic.common.util.json.JsonUtil;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;

View File

@@ -4,8 +4,8 @@ import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.service.SemanticLayer;
import com.tencent.supersonic.chat.application.knowledge.NatureHelper;
import com.tencent.supersonic.chat.application.knowledge.WordNatureService;
import com.tencent.supersonic.chat.application.mapper.SearchMatchStrategy;
import com.tencent.supersonic.chat.domain.pojo.search.DomainInfoStat;
import com.tencent.supersonic.chat.domain.pojo.search.DomainWithSemanticType;
@@ -15,12 +15,10 @@ import com.tencent.supersonic.chat.domain.pojo.semantic.DomainInfos;
import com.tencent.supersonic.chat.domain.service.ChatService;
import com.tencent.supersonic.chat.domain.service.SearchService;
import com.tencent.supersonic.chat.domain.utils.NatureConverter;
import com.tencent.supersonic.chat.domain.utils.SchemaInfoConverter;
import com.tencent.supersonic.common.nlp.ItemDO;
import com.tencent.supersonic.common.nlp.MapResult;
import com.tencent.supersonic.common.nlp.NatureType;
import com.tencent.supersonic.common.nlp.WordNature;
import com.tencent.supersonic.knowledge.application.online.BaseWordNature;
import com.tencent.supersonic.knowledge.infrastructure.nlp.HanlpHelper;
import java.util.ArrayList;
import java.util.Comparator;
@@ -47,12 +45,11 @@ import org.springframework.stereotype.Service;
@Service
public class SearchServiceImpl implements SearchService {
private final Logger logger = LoggerFactory.getLogger(SearchServiceImpl.class);
private static final Logger LOGGER = LoggerFactory.getLogger(SearchServiceImpl.class);
@Autowired
private SemanticLayer semanticLayer;
private WordNatureService wordNatureService;
@Autowired
private ChatService chatService;
@Autowired
private SearchMatchStrategy searchMatchStrategy;
@@ -63,13 +60,16 @@ public class SearchServiceImpl implements SearchService {
public List<SearchResult> search(QueryContextReq queryCtx) {
String queryText = queryCtx.getQueryText();
// 1.get meta info
DomainInfos domainInfosDb = SchemaInfoConverter.convert(semanticLayer.getDomainSchemaInfo(new ArrayList<>()));
DomainInfos domainInfosDb = wordNatureService.getCache().getUnchecked("");
List<ItemDO> metricsDb = domainInfosDb.getMetrics();
final Map<Integer, String> domainToName = domainInfosDb.getDomainToName();
// 2.detect by segment
List<Term> originals = HanlpHelper.getSegment().seg(queryText).stream().collect(Collectors.toList());
Map<MatchText, List<MapResult>> regTextMap = searchMatchStrategy.matchWithMatchText(queryText, originals);
List<Term> originals = HanlpHelper.getSegment().seg(queryText.toLowerCase()).stream()
.collect(Collectors.toList());
Map<MatchText, List<MapResult>> regTextMap = searchMatchStrategy.matchWithMatchText(queryText, originals,
queryCtx.getDomainId());
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
// 3.get the most matching data
Optional<Entry<MatchText, List<MapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
.stream()
@@ -77,28 +77,28 @@ public class SearchServiceImpl implements SearchService {
.reduce((entry1, entry2) ->
entry1.getKey().getDetectSegment().length() >= entry2.getKey().getDetectSegment().length()
? entry1 : entry2);
logger.debug("mostSimilarSearchResult:{}", mostSimilarSearchResult);
LOGGER.debug("mostSimilarSearchResult:{}", mostSimilarSearchResult);
// 4.optimize the results after the query
if (!mostSimilarSearchResult.isPresent()) {
logger.info("unable to find any information through search , queryCtx:{}", queryCtx);
LOGGER.info("unable to find any information through search , queryCtx:{}", queryCtx);
return Lists.newArrayList();
}
Map.Entry<MatchText, List<MapResult>> searchTextEntry = mostSimilarSearchResult.get();
logger.info("searchTextEntry:{},queryCtx:{}", searchTextEntry, queryCtx);
LOGGER.info("searchTextEntry:{},queryCtx:{}", searchTextEntry, queryCtx);
Set<SearchResult> searchResults = new LinkedHashSet();
DomainInfoStat domainStat = NatureHelper.getDomainStat(originals);
List<Integer> possibleDomains = getPossibleDomains(queryCtx, originals, domainStat);
List<Integer> possibleDomains = getPossibleDomains(queryCtx, originals, domainStat, queryCtx.getDomainId());
// 4.1 priority dimension metric
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDomains), domainToName,
searchTextEntry,
searchResults);
searchTextEntry, searchResults);
// 4.2 process based on dimension values
MatchText matchText = searchTextEntry.getKey();
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry);
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleDomains));
LOGGER.debug("possibleDomains:{},natureToNameMap:{}", possibleDomains, natureToNameMap);
for (Map.Entry<String, String> natureToNameEntry : natureToNameMap.entrySet()) {
searchDimensionValue(metricsDb, domainToName, domainStat.getMetricDomainCount(), searchResults,
@@ -108,12 +108,19 @@ public class SearchServiceImpl implements SearchService {
}
private List<Integer> getPossibleDomains(QueryContextReq queryCtx, List<Term> originals,
DomainInfoStat domainStat) {
DomainInfoStat domainStat, Integer webDomainId) {
if (Objects.nonNull(webDomainId) && webDomainId > 0) {
List<Integer> result = new ArrayList<>();
result.add(webDomainId);
return result;
}
List<Integer> possibleDomains = NatureHelper.selectPossibleDomains(originals);
Long contextDomain = chatService.getContextDomain(queryCtx.getChatId());
logger.debug("possibleDomains:{},domainStat:{},contextDomain:{}", possibleDomains, domainStat, contextDomain);
LOGGER.debug("possibleDomains:{},domainStat:{},contextDomain:{}", possibleDomains, domainStat, contextDomain);
// If nothing is recognized or only metrics are present, then add the contextDomain.
if (nothingOrOnlyMetric(domainStat) && effectiveDomain(contextDomain)) {
@@ -195,16 +202,25 @@ public class SearchServiceImpl implements SearchService {
* @param recommendTextListEntry
* @return
*/
private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<MapResult>> recommendTextListEntry) {
private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<MapResult>> recommendTextListEntry,
Set<Integer> possibleDomains) {
List<MapResult> recommendValues = recommendTextListEntry.getValue();
return recommendValues.stream()
.flatMap(entry -> entry.getNatures().stream().map(nature -> {
WordNature posDO = new WordNature();
posDO.setWord(entry.getName());
posDO.setNature(nature);
return posDO;
}
)).sorted(Comparator.comparingInt(a -> a.getWord().length()))
.flatMap(entry -> entry.getNatures().stream()
.filter(nature -> {
if (CollectionUtils.isEmpty(possibleDomains)) {
return true;
}
Integer domain = NatureHelper.getDomain(nature);
return possibleDomains.contains(domain);
})
.map(nature -> {
WordNature posDO = new WordNature();
posDO.setWord(entry.getName());
posDO.setNature(nature);
return posDO;
}
)).sorted(Comparator.comparingInt(a -> a.getWord().length()))
.collect(Collectors.toMap(WordNature::getNature, WordNature::getWord, (value1, value2) -> value1,
LinkedHashMap::new));
}
@@ -233,7 +249,7 @@ public class SearchServiceImpl implements SearchService {
domainToName.get(domain), domain, semanticType));
}
}
logger.info("parseResult:{},dimensionMetricClassIds:{},possibleDomains:{}", mapResult,
LOGGER.info("parseResult:{},dimensionMetricClassIds:{},possibleDomains:{}", mapResult,
dimensionMetricClassIds, possibleDomains);
}
return existMetric;

View File

@@ -10,7 +10,6 @@ import org.springframework.context.ApplicationListener;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
@Slf4j
@@ -23,7 +22,6 @@ public class ApplicationStartedInit implements ApplicationListener<ApplicationSt
@Autowired
private WordNatureService wordNatureService;
private List<WordNature> preWordNatures = new ArrayList<>();
@Override
public void onApplicationEvent(ApplicationStartedEvent event) {
@@ -32,7 +30,7 @@ public class ApplicationStartedInit implements ApplicationListener<ApplicationSt
List<WordNature> wordNatures = wordNatureService.getAllWordNature();
this.preWordNatures = wordNatures;
wordNatureService.setPreWordNatures(wordNatures);
onlineKnowledgeService.reloadAllData(wordNatures);
@@ -51,14 +49,15 @@ public class ApplicationStartedInit implements ApplicationListener<ApplicationSt
try {
List<WordNature> wordNatures = wordNatureService.getAllWordNature();
List<WordNature> preWordNatures = wordNatureService.getPreWordNatures();
if (CollectionUtils.isEqualCollection(wordNatures, preWordNatures)) {
log.debug("wordNatures is not change, reloadKnowledge end");
return;
}
this.preWordNatures = wordNatures;
wordNatureService.setPreWordNatures(wordNatures);
onlineKnowledgeService.updateOnlineKnowledge(wordNatureService.getAllWordNature());
wordNatureService.getCache().refresh("");
} catch (Exception e) {
log.error("reloadKnowledge error", e);

View File

@@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory;
public class NatureHelper {
private static final Logger LOGGER = LoggerFactory.getLogger(NatureHelper.class);
private static boolean isDomainOrEntity(Term term, Integer domain) {
return (NatureType.NATURE_SPILT + domain).equals(term.nature.toString()) || term.nature.toString()
.endsWith(NatureType.ENTITY.getType());
@@ -96,7 +97,7 @@ public class NatureHelper {
/**
* Get the number of types of class parts of speech
* classId -> (nature , natureCount)
* domainId -> (nature , natureCount)
*
* @param terms
* @return

View File

@@ -1,5 +1,8 @@
package com.tencent.supersonic.chat.application.knowledge;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.tencent.supersonic.chat.api.service.SemanticLayer;
import com.tencent.supersonic.chat.domain.pojo.semantic.DomainInfos;
import com.tencent.supersonic.chat.domain.utils.SchemaInfoConverter;
@@ -9,6 +12,7 @@ import com.tencent.supersonic.common.nlp.WordNature;
import com.tencent.supersonic.knowledge.application.online.WordNatureStrategyFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
@@ -21,10 +25,25 @@ import org.springframework.stereotype.Service;
@Service
public class WordNatureService {
private final Logger logger = LoggerFactory.getLogger(WordNatureService.class);
private static final Logger LOGGER = LoggerFactory.getLogger(WordNatureService.class);
@Autowired
private SemanticLayer semanticLayer;
private static final Integer META_CACHE_TIME = 5;
private List<WordNature> preWordNatures = new ArrayList<>();
private LoadingCache<String, DomainInfos> cache = CacheBuilder.newBuilder()
.expireAfterWrite(META_CACHE_TIME, TimeUnit.MINUTES)
.build(
new CacheLoader<String, DomainInfos>() {
@Override
public DomainInfos load(String key) {
LOGGER.info("load getDomainSchemaInfo cache [{}]", key);
return SchemaInfoConverter.convert(semanticLayer.getDomainSchemaInfo(new ArrayList<>()));
}
}
);
public List<WordNature> getAllWordNature() {
@@ -45,7 +64,19 @@ public class WordNatureService {
private void addNatureToResult(NatureType value, List<ItemDO> metas, List<WordNature> natures) {
List<WordNature> natureList = WordNatureStrategyFactory.get(value).getWordNatureList(metas);
logger.debug("nature type:{} , nature size:{}", value.name(), natureList.size());
LOGGER.debug("nature type:{} , nature size:{}", value.name(), natureList.size());
natures.addAll(natureList);
}
public List<WordNature> getPreWordNatures() {
return preWordNatures;
}
public void setPreWordNatures(List<WordNature> preWordNatures) {
this.preWordNatures = preWordNatures;
}
public LoadingCache<String, DomainInfos> getCache() {
return cache;
}
}
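
Note: the LoadingCache above keeps the converted domain schema (DomainInfos) for META_CACHE_TIME (5) minutes; SearchServiceImpl reads it via getCache().getUnchecked("") and ApplicationStartedInit calls getCache().refresh("") once the word natures change. A minimal standalone sketch of the same Guava pattern (the loader body is a stand-in, not the project's real fetch):

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import java.util.concurrent.TimeUnit;

public class MetaCacheSketch {
    private final LoadingCache<String, String> cache = CacheBuilder.newBuilder()
            .expireAfterWrite(5, TimeUnit.MINUTES)
            .build(new CacheLoader<String, String>() {
                @Override
                public String load(String key) {
                    // stand-in for SchemaInfoConverter.convert(semanticLayer.getDomainSchemaInfo(...))
                    return "domain schema loaded at " + System.currentTimeMillis();
                }
            });

    public String read() {
        return cache.getUnchecked(""); // loads on first access, then serves the cached value
    }

    public void reload() {
        cache.refresh(""); // force a reload, e.g. after the word dictionary changes
    }
}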

View File

@@ -12,6 +12,7 @@ import com.tencent.supersonic.common.nlp.MapResult;
import com.tencent.supersonic.common.nlp.NatureType;
import com.tencent.supersonic.common.util.context.ContextUtils;
import com.tencent.supersonic.common.util.json.JsonUtil;
import com.tencent.supersonic.knowledge.application.online.BaseWordNature;
import com.tencent.supersonic.knowledge.application.online.WordNatureStrategyFactory;
import com.tencent.supersonic.knowledge.infrastructure.nlp.HanlpHelper;
import java.util.ArrayList;
@@ -28,18 +29,22 @@ public class HanlpSchemaMapper implements SchemaMapper {
private static final Logger LOGGER = LoggerFactory.getLogger(HanlpSchemaMapper.class);
@Override
public void map(QueryContextReq searchCtx) {
public void map(QueryContextReq queryContext) {
List<Term> terms = HanlpHelper.getSegment().seg(queryContext.getQueryText().toLowerCase()).stream()
.collect(Collectors.toList());
List<Term> terms = HanlpHelper.getSegment().seg(searchCtx.getQueryText()).stream().collect(Collectors.toList());
terms.forEach(
item -> LOGGER.info("word:{},nature:{},frequency:{}", item.word, item.nature.toString(), item.frequency)
item -> LOGGER.info("word:{},nature:{},frequency:{}", item.word, item.nature.toString(),
item.getFrequency())
);
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
List<MapResult> matches = matchStrategy.match(searchCtx.getQueryText(), terms, 0);
LOGGER.info("searchCtx:{},matches:{}", searchCtx, matches);
List<MapResult> matches = matchStrategy.match(queryContext.getQueryText(), terms, queryContext.getDomainId());
HanlpHelper.transLetterOriginal(matches);
LOGGER.info("queryContext:{},matches:{}", queryContext, matches);
convertTermsToSchemaMapInfo(matches, searchCtx.getMapInfo());
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo());
}
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap) {
@@ -59,11 +64,14 @@ public class HanlpSchemaMapper implements SchemaMapper {
SchemaElementMatch schemaElementMatch = new SchemaElementMatch();
schemaElementMatch.setElementType(elementType);
Integer elementID = WordNatureStrategyFactory.get(NatureType.getNatureType(nature))
.getElementID(nature);
BaseWordNature baseWordNature = WordNatureStrategyFactory.get(NatureType.getNatureType(nature));
Integer elementID = baseWordNature.getElementID(nature);
schemaElementMatch.setElementID(elementID);
Long frequency = baseWordNature.getFrequency(nature);
schemaElementMatch.setFrequency(frequency);
schemaElementMatch.setWord(mapResult.getName());
schemaElementMatch.setSimilarity(mapResult.getSimilarity());
schemaElementMatch.setDetectWord(mapResult.getDetectWord());
Map<Integer, List<SchemaElementMatch>> domainElementMatches = schemaMap.getDomainElementMatches();
List<SchemaElementMatch> schemaElementMatches = domainElementMatches.putIfAbsent(domain,

View File

@@ -13,18 +13,10 @@ import java.util.Map;
*/
public interface MatchStrategy {
/***
* match
* @param terms
* @return
*/
List<MapResult> match(String text, List<Term> terms, int retryCount);
List<MapResult> match(String text, List<Term> terms, Integer detectDomainId);
List<MapResult> match(String text, List<Term> terms, int retryCount, Integer detectDomainId);
Map<MatchText, List<MapResult>> matchWithMatchText(String text, List<Term> originals);
Map<MatchText, List<MapResult>> matchWithMatchText(String text, List<Term> originals, Integer detectDomainId);
/***
* exist dimension values

View File

@@ -40,35 +40,31 @@ public class QueryMatchStrategy implements MatchStrategy {
private Double dimensionValueThresholdConfig;
@Override
public List<MapResult> match(String text, List<Term> terms, int retryCount) {
return match(text, terms, retryCount, null);
}
@Override
public List<MapResult> match(String text, List<Term> terms, int retryCount, Integer detectDomainId) {
public List<MapResult> match(String text, List<Term> terms, Integer detectDomainId) {
if (CollectionUtils.isEmpty(terms) || StringUtils.isEmpty(text)) {
return null;
}
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
LOGGER.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectDomainId:{}", retryCount, terms,
regOffsetToLength, offsetList,
detectDomainId);
LOGGER.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectDomainId:{}", terms,
regOffsetToLength, offsetList, detectDomainId);
return detect(text, regOffsetToLength, offsetList, detectDomainId, retryCount);
return detect(text, regOffsetToLength, offsetList, detectDomainId);
}
@Override
public Map<MatchText, List<MapResult>> matchWithMatchText(String text, List<Term> originals) {
public Map<MatchText, List<MapResult>> matchWithMatchText(String text, List<Term> originals,
Integer detectDomainId) {
return null;
}
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Integer detectDomainId, int retryCount) {
Integer detectDomainId) {
List<MapResult> results = Lists.newArrayList();
for (Integer index = 0; index <= text.length() - 1; ) {
@@ -79,7 +75,7 @@ public class QueryMatchStrategy implements MatchStrategy {
int offset = getStepOffset(offsetList, index);
i = getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
List<MapResult> mapResults = detectByStep(text, detectDomainId, index, i, offset, retryCount);
List<MapResult> mapResults = detectByStep(text, detectDomainId, index, i, offset);
mapResultRowSet.addAll(mapResults);
}
}
@@ -90,8 +86,7 @@ public class QueryMatchStrategy implements MatchStrategy {
return results;
}
private List<MapResult> detectByStep(String text, Integer detectClassId, Integer index, Integer i, int offset,
int retryCount) {
private List<MapResult> detectByStep(String text, Integer detectDomainId, Integer index, Integer i, int offset) {
String detectSegment = text.substring(index, i);
// step1. pre search
LinkedHashSet<MapResult> mapResults = Suggester.prefixSearch(detectSegment, oneDetectionMaxSize)
@@ -109,11 +104,11 @@ public class QueryMatchStrategy implements MatchStrategy {
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by classId
if (Objects.nonNull(detectClassId) && detectClassId > 0) {
if (Objects.nonNull(detectDomainId) && detectDomainId > 0) {
LOGGER.debug("detectDomainId:{}, before parseResults:{}", mapResults);
mapResults = mapResults.stream().map(entry -> {
List<String> natures = entry.getNatures().stream().filter(
nature -> nature.startsWith(NatureType.NATURE_SPILT + detectClassId) || (nature.startsWith(
nature -> nature.startsWith(NatureType.NATURE_SPILT + detectDomainId) || (nature.startsWith(
NatureType.NATURE_SPILT))
).collect(Collectors.toList());
entry.setNatures(natures);
@@ -123,8 +118,7 @@ public class QueryMatchStrategy implements MatchStrategy {
}
// step5. filter by similarity
mapResults = mapResults.stream()
.filter(term -> getSimilarity(detectSegment, term.getName()) >= getThresholdMatch(term.getNatures(),
retryCount))
.filter(term -> getSimilarity(detectSegment, term.getName()) >= getThresholdMatch(term.getNatures()))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new));
@@ -170,11 +164,11 @@ public class QueryMatchStrategy implements MatchStrategy {
return index;
}
private double getThresholdMatch(List<String> natures, int retryCount) {
private double getThresholdMatch(List<String> natures) {
if (existDimensionValues(natures)) {
return dimensionValueThresholdConfig;
}
return metricDimensionThresholdConfig - STEP * retryCount;
return metricDimensionThresholdConfig;
}
}
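
Note: both match strategies now narrow candidate natures by a domain prefix when a web domainId is supplied. The sketch below assumes a nature encoding of the form NATURE_SPILT + domainId + ... with "_" as the separator (an assumption inferred from the startsWith checks above, not confirmed by this diff):

import java.util.List;
import java.util.stream.Collectors;

public class NatureDomainFilterSketch {
    private static final String NATURE_SPILT = "_"; // assumed separator

    static List<String> keepDomain(List<String> natures, Integer detectDomainId) {
        if (detectDomainId == null || detectDomainId <= 0) {
            return natures; // no web domain pinned: keep all candidates
        }
        return natures.stream()
                .filter(nature -> nature.startsWith(NATURE_SPILT + detectDomainId))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // only natures of domain 3 survive when the request pins domainId = 3
        System.out.println(keepDomain(List.of("_3_12", "_7_5", "_3_8"), 3)); // [_3_12, _3_8]
    }
}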

View File

@@ -23,20 +23,16 @@ public class SearchMatchStrategy implements MatchStrategy {
private static final int SEARCH_SIZE = 3;
@Override
public List<MapResult> match(String text, List<Term> terms, int retryCount) {
public List<MapResult> match(String text, List<Term> terms, Integer detectDomainId) {
return null;
}
@Override
public List<MapResult> match(String text, List<Term> terms, int retryCount, Integer detectDomainId) {
return null;
}
@Override
public Map<MatchText, List<MapResult>> matchWithMatchText(String text, List<Term> originals) {
public Map<MatchText, List<MapResult>> matchWithMatchText(String text, List<Term> originals,
Integer detectDomainId) {
Map<Integer, Integer> regOffsetToLength = originals.stream()
.filter(entry -> !entry.nature.toString().startsWith(NatureType.NATURE_SPILT))
@@ -70,7 +66,17 @@ public class SearchMatchStrategy implements MatchStrategy {
mapResults = mapResults.stream().filter(entry -> {
List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(NatureType.ENTITY.getType()))
.collect(Collectors.toList());
.filter(nature -> {
if (Objects.isNull(detectDomainId) || detectDomainId <= 0) {
return true;
}
if (nature.startsWith(NatureType.NATURE_SPILT + detectDomainId)
&& nature.startsWith(NatureType.NATURE_SPILT)) {
return true;
}
return false;
}
).collect(Collectors.toList());
if (CollectionUtils.isEmpty(natures)) {
return false;
}

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.common.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.SchemaItem;
import com.tencent.supersonic.common.util.context.ContextUtils;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
@@ -21,7 +22,7 @@ import org.springframework.stereotype.Component;
public class AggregateSemanticParser implements SemanticParser {
private final Logger logger = LoggerFactory.getLogger(AggregateSemanticParser.class);
public static final Integer TOPN_LIMIT = 10;
public static final Integer TOPN_LIMIT = 1000;
private AggregateTypeResolver aggregateTypeResolver;
@@ -33,7 +34,7 @@ public class AggregateSemanticParser implements SemanticParser {
SemanticParseInfo semanticParse = queryContext.getParseInfo();
List<SchemaItem> metrics = semanticParse.getMetrics();
Set<SchemaItem> metrics = semanticParse.getMetrics();
semanticParse.setNativeQuery(getNativeQuery(aggregateType, queryContext));
@@ -47,12 +48,12 @@ public class AggregateSemanticParser implements SemanticParser {
/**
* query mode reset by the AggregateType
*
* @param searchCtx
* @param queryContext
* @param aggregateType
*/
private void resetQueryModeByAggregateType(QueryContextReq searchCtx, AggregateTypeEnum aggregateType) {
private void resetQueryModeByAggregateType(QueryContextReq queryContext, AggregateTypeEnum aggregateType) {
SemanticParseInfo parseInfo = searchCtx.getParseInfo();
SemanticParseInfo parseInfo = queryContext.getParseInfo();
String queryMode = parseInfo.getQueryMode();
if (MetricGroupBy.QUERY_MODE.equals(queryMode) || MetricGroupBy.QUERY_MODE.equals(queryMode)) {
if (AggregateTypeEnum.MAX.equals(aggregateType) || AggregateTypeEnum.MIN.equals(aggregateType)
@@ -63,7 +64,7 @@ public class AggregateSemanticParser implements SemanticParser {
}
}
if (MetricFilter.QUERY_MODE.equals(queryMode) || MetricCompare.QUERY_MODE.equals(queryMode)) {
if (aggregateTypeResolver.hasCompareIntentionalWords(searchCtx.getQueryText())) {
if (aggregateTypeResolver.hasCompareIntentionalWords(queryContext.getQueryText())) {
parseInfo.setQueryMode(MetricCompare.QUERY_MODE);
} else {
parseInfo.setQueryMode(MetricFilter.QUERY_MODE);
@@ -72,11 +73,11 @@ public class AggregateSemanticParser implements SemanticParser {
logger.info("queryMode mode [{}]->[{}]", queryMode, parseInfo.getQueryMode());
}
private boolean getNativeQuery(AggregateTypeEnum aggregateType, QueryContextReq searchCtx) {
private boolean getNativeQuery(AggregateTypeEnum aggregateType, QueryContextReq queryContext) {
if (AggregateTypeEnum.TOPN.equals(aggregateType)) {
return true;
}
return searchCtx.getParseInfo().getNativeQuery();
return queryContext.getParseInfo().getNativeQuery();
}

View File

@@ -9,8 +9,6 @@ import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.service.SemanticParser;
import com.tencent.supersonic.semantic.api.core.response.DimSchemaResp;
import com.tencent.supersonic.semantic.api.core.response.MetricSchemaResp;
import com.tencent.supersonic.chat.application.parser.resolver.DomainResolver;
import com.tencent.supersonic.chat.application.query.EntityDetail;
import com.tencent.supersonic.chat.application.query.EntityListFilter;
@@ -24,8 +22,10 @@ import com.tencent.supersonic.chat.domain.utils.DefaultSemanticInternalUtils;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.SchemaItem;
import com.tencent.supersonic.common.util.context.ContextUtils;
import java.util.ArrayList;
import com.tencent.supersonic.semantic.api.core.response.DimSchemaResp;
import com.tencent.supersonic.semantic.api.core.response.MetricSchemaResp;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@@ -104,33 +104,39 @@ public class DefaultMetricSemanticParser implements SemanticParser {
return primaryDimensions;
}
protected void addEntityDetailDimensionMetric(QueryContextReq searchCtx, ChatContext chatCtx) {
if (searchCtx.getParseInfo().getDomainId() > 0) {
ChatConfigRichInfo chaConfigRichDesc = defaultSemanticUtils.getChatConfigRichInfo(
searchCtx.getParseInfo().getDomainId());
protected void addEntityDetailDimensionMetric(QueryContextReq queryContext, ChatContext chatCtx) {
if (queryContext.getParseInfo().getDomainId() > 0) {
Long domainId = queryContext.getParseInfo().getDomainId();
ChatConfigRichInfo chaConfigRichDesc = defaultSemanticUtils.getChatConfigRichInfo(domainId);
if (chaConfigRichDesc != null) {
SemanticParseInfo semanticParseInfo = searchCtx.getParseInfo();
if (Objects.nonNull(semanticParseInfo) && CollectionUtils.isEmpty(semanticParseInfo.getDimensions())) {
List<SchemaItem> dimensions = new ArrayList<>();
List<SchemaItem> metrics = new ArrayList<>();
if (chaConfigRichDesc.getEntity() != null
&& chaConfigRichDesc.getEntity().getEntityInternalDetailDesc() != null) {
chaConfigRichDesc.getEntity().getEntityInternalDetailDesc().getMetricList().stream()
.forEach(m -> metrics.add(getMetric(m)));
chaConfigRichDesc.getEntity().getEntityInternalDetailDesc().getDimensionList().stream()
.forEach(m -> dimensions.add(getDimension(m)));
}
semanticParseInfo.setDimensions(dimensions);
semanticParseInfo.setMetrics(metrics);
if (chaConfigRichDesc.getEntity() == null
|| chaConfigRichDesc.getEntity().getEntityInternalDetailDesc() == null) {
return;
}
SemanticParseInfo semanticParseInfo = queryContext.getParseInfo();
Set<SchemaItem> metrics = new LinkedHashSet();
chaConfigRichDesc.getEntity().getEntityInternalDetailDesc().getMetricList().stream()
.forEach(m -> metrics.add(getMetric(m)));
semanticParseInfo.setMetrics(metrics);
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo()
.getMatchedElements(domainId.intValue());
if (CollectionUtils.isEmpty(schemaElementMatches) || schemaElementMatches.stream()
.filter(s -> SchemaElementType.DIMENSION.equals(s.getElementType())).count() <= 0) {
logger.info("addEntityDetailDimensionMetric catch");
Set<SchemaItem> dimensions = new LinkedHashSet();
chaConfigRichDesc.getEntity().getEntityInternalDetailDesc().getDimensionList().stream()
.forEach(m -> dimensions.add(getDimension(m)));
semanticParseInfo.setDimensions(dimensions);
}
}
}
}
protected void defaultQueryMode(QueryContextReq searchCtx, ChatContext chatCtx) {
SchemaMapInfo schemaMap = searchCtx.getMapInfo();
SemanticParseInfo parseInfo = searchCtx.getParseInfo();
protected void defaultQueryMode(QueryContextReq queryContext, ChatContext chatCtx) {
SchemaMapInfo schemaMap = queryContext.getMapInfo();
SemanticParseInfo parseInfo = queryContext.getParseInfo();
if (StringUtils.isEmpty(parseInfo.getQueryMode())) {
if (chatCtx.getParseInfo() != null && chatCtx.getParseInfo().getDomainId() > 0) {
//
@@ -182,12 +188,12 @@ public class DefaultMetricSemanticParser implements SemanticParser {
}
private void fillDateDomain(ChatContext chatCtx, QueryContextReq searchCtx) {
SemanticParseInfo parseInfo = searchCtx.getParseInfo();
private void fillDateDomain(ChatContext chatCtx, QueryContextReq queryContext) {
SemanticParseInfo parseInfo = queryContext.getParseInfo();
if (parseInfo == null || parseInfo.getDateInfo() == null) {
boolean isUpdateTime = false;
if (selectStrategy.isDomainSwitch(chatCtx, searchCtx)) {
if (selectStrategy.isDomainSwitch(chatCtx, queryContext)) {
isUpdateTime = true;
}
if (chatCtx.getParseInfo() == null
@@ -212,7 +218,7 @@ public class DefaultMetricSemanticParser implements SemanticParser {
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) && CollectionUtils.isEmpty(
semanticParseInfo.getDimensions())) {
List<SchemaItem> metrics = new ArrayList<>();
Set<SchemaItem> metrics = new LinkedHashSet();
chaConfigRichDesc.getDefaultMetrics().stream().forEach(metric -> {
SchemaItem metricTmp = new SchemaItem();
metricTmp.setId(metric.getMetricId());

View File

@@ -25,6 +25,7 @@ import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.support.SpringFactoriesLoader;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -32,10 +33,15 @@ import org.springframework.util.CollectionUtils;
public class DomainSemanticParser implements SemanticParser {
private final Logger logger = LoggerFactory.getLogger(DomainSemanticParser.class);
private DomainResolver domainResolver;
private List<DomainResolver> domainResolverList;
private SemanticQueryResolver semanticQueryResolver;
public DomainSemanticParser() {
domainResolverList = SpringFactoriesLoader.loadFactories(DomainResolver.class,
Thread.currentThread().getContextClassLoader());
}
@Override
public boolean parse(QueryContextReq queryContext, ChatContext chatCtx) {
DomainInfos domainInfosDb = SchemaInfoConverter.convert(
@@ -45,7 +51,7 @@ public class DomainSemanticParser implements SemanticParser {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
SemanticParseInfo parseInfo = queryContext.getParseInfo();
domainResolver = ContextUtils.getBean(DomainResolver.class);
//domainResolver = ContextUtils.getBean(DomainResolver.class);
semanticQueryResolver = ContextUtils.getBean(SemanticQueryResolver.class);
Map<Integer, SemanticQuery> domainSemanticQuery = new HashMap<>();
@@ -72,15 +78,19 @@ public class DomainSemanticParser implements SemanticParser {
}
} else if (domainSemanticQuery.size() > 1) {
// will choose one by the domain select
Integer domainId = domainResolver.resolve(domainSemanticQuery, queryContext, chatCtx, mapInfo);
if (domainId > 0) {
Map.Entry<Integer, SemanticQuery> match = domainSemanticQuery.entrySet().stream()
.filter(entry -> entry.getKey().equals(domainId)).findFirst().orElse(null);
logger.info("select by selectStrategy [{}:{}]", domainId, match.getValue());
parseInfo.setDomainId(Long.valueOf(match.getKey()));
parseInfo.setDomainName(domainToName.get(Integer.valueOf(match.getKey())));
parseInfo.setQueryMode(match.getValue().getQueryMode());
return false;
Optional<Integer> domainId = domainResolverList.stream()
.map(domainResolver -> domainResolver.resolve(domainSemanticQuery, queryContext, chatCtx, mapInfo))
.filter(d -> d > 0).findFirst();
if (domainId.isPresent() && domainId.get() > 0) {
for (Map.Entry<Integer, SemanticQuery> match : domainSemanticQuery.entrySet()) {
if (match.getKey().equals(domainId.get())) {
logger.info("select by selectStrategy [{}:{}]", domainId.get(), match.getValue());
parseInfo.setDomainId(Long.valueOf(match.getKey()));
parseInfo.setDomainName(domainToName.get(Integer.valueOf(match.getKey())));
parseInfo.setQueryMode(match.getValue().getQueryMode());
return false;
}
}
}
}
// Round 2: no domains can be found yet, count in chat context
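
Note: DomainSemanticParser now discovers DomainResolver implementations via SpringFactoriesLoader instead of a single injected bean. loadFactories instantiates every class listed under the interface's fully qualified name in META-INF/spring.factories (each implementation needs a public no-arg constructor), so a registration along these lines is assumed; the actual entries are not part of this diff:

# META-INF/spring.factories (illustrative, assumed entry)
com.tencent.supersonic.chat.application.parser.resolver.DomainResolver=\
com.tencent.supersonic.chat.application.parser.resolver.HeuristicDomainResolver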

View File

@@ -19,6 +19,7 @@ import com.tencent.supersonic.common.pojo.SchemaItem;
import com.tencent.supersonic.common.util.context.ContextUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@@ -43,7 +44,7 @@ public class ListFilterParser implements SemanticParser {
this.fillDateEntityFilter(queryContext.getParseInfo());
this.addEntityDetailAndOrderByMetric(queryContext, chatCtx);
this.dealNativeQuery(queryContext, true);
return false;
return true;
}
@@ -56,15 +57,15 @@ public class ListFilterParser implements SemanticParser {
semanticParseInfo.setDateInfo(dateInfo);
}
private void addEntityDetailAndOrderByMetric(QueryContextReq searchCtx, ChatContext chatCtx) {
if (searchCtx.getParseInfo().getDomainId() > 0L) {
private void addEntityDetailAndOrderByMetric(QueryContextReq queryContext, ChatContext chatCtx) {
if (queryContext.getParseInfo().getDomainId() > 0L) {
ChatConfigRichInfo chaConfigRichDesc = defaultSemanticUtils.getChatConfigRichInfo(
searchCtx.getParseInfo().getDomainId());
queryContext.getParseInfo().getDomainId());
if (chaConfigRichDesc != null) {
SemanticParseInfo semanticParseInfo = searchCtx.getParseInfo();
List<SchemaItem> dimensions = new ArrayList();
SemanticParseInfo semanticParseInfo = queryContext.getParseInfo();
Set<SchemaItem> dimensions = new LinkedHashSet();
Set<String> primaryDimensions = this.addPrimaryDimension(chaConfigRichDesc.getEntity(), dimensions);
List<SchemaItem> metrics = new ArrayList();
Set<SchemaItem> metrics = new LinkedHashSet();
if (chaConfigRichDesc.getEntity() != null
&& chaConfigRichDesc.getEntity().getEntityInternalDetailDesc() != null) {
chaConfigRichDesc.getEntity().getEntityInternalDetailDesc().getMetricList().stream()
@@ -76,7 +77,7 @@ public class ListFilterParser implements SemanticParser {
semanticParseInfo.setDimensions(dimensions);
semanticParseInfo.setMetrics(metrics);
List<Order> orders = new ArrayList();
Set<Order> orders = new LinkedHashSet();
if (chaConfigRichDesc.getEntity() != null
&& chaConfigRichDesc.getEntity().getEntityInternalDetailDesc() != null) {
chaConfigRichDesc.getEntity().getEntityInternalDetailDesc().getMetricList().stream()
@@ -89,7 +90,7 @@ public class ListFilterParser implements SemanticParser {
}
private Set<String> addPrimaryDimension(EntityRichInfo entity, List<SchemaItem> dimensions) {
private Set<String> addPrimaryDimension(EntityRichInfo entity, Set<SchemaItem> dimensions) {
Set<String> primaryDimensions = new HashSet();
if (!Objects.isNull(entity) && !CollectionUtils.isEmpty(entity.getEntityIds())) {
entity.getEntityIds().stream().forEach((dimSchemaDesc) -> {
@@ -120,9 +121,9 @@ public class ListFilterParser implements SemanticParser {
return queryMeta;
}
private void dealNativeQuery(QueryContextReq searchCtx, boolean isNativeQuery) {
if (Objects.nonNull(searchCtx) && Objects.nonNull(searchCtx.getParseInfo())) {
searchCtx.getParseInfo().setNativeQuery(isNativeQuery);
private void dealNativeQuery(QueryContextReq queryContext, boolean isNativeQuery) {
if (Objects.nonNull(queryContext) && Objects.nonNull(queryContext.getParseInfo())) {
queryContext.getParseInfo().setNativeQuery(isNativeQuery);
}
}

View File

@@ -0,0 +1,131 @@
package com.tencent.supersonic.chat.application.parser.resolver;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementCount;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.service.SemanticQuery;
import com.tencent.supersonic.chat.domain.utils.ContextHelper;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class BaseDomainResolver implements DomainResolver {
@Override
public boolean isDomainSwitch(ChatContext chatCtx, QueryContextReq searchCtx) {
Long contextDomain = chatCtx.getParseInfo().getDomainId();
Long currentDomain = searchCtx.getParseInfo().getDomainId();
boolean noSwitch = currentDomain == null || contextDomain == null || contextDomain.equals(currentDomain);
log.info("ChatContext isDomainSwitch [{}]", !noSwitch);
return !noSwitch;
}
public abstract Integer selectDomain(Map<Integer, SemanticQuery> domainQueryModes, QueryContextReq searchCtx,
ChatContext chatCtx, SchemaMapInfo schemaMap);
@Override
public Integer resolve(Map<Integer, SemanticQuery> domainQueryModes, QueryContextReq searchCtx,
ChatContext chatCtx, SchemaMapInfo schemaMap) {
Integer selectDomain = selectDomain(domainQueryModes, searchCtx, chatCtx, schemaMap);
if (selectDomain > 0) {
log.info("selectDomain {} ", selectDomain);
return selectDomain;
}
// get the max SchemaElementType number
return selectDomainBySchemaElementCount(domainQueryModes, schemaMap);
}
protected static Integer selectDomainBySchemaElementCount(Map<Integer, SemanticQuery> domainQueryModes,
SchemaMapInfo schemaMap) {
Map<Integer, SchemaElementCount> domainTypeMap = getDomainTypeMap(schemaMap);
if (domainTypeMap.size() == 1) {
Integer domainSelect = domainTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (domainQueryModes.containsKey(domainSelect)) {
log.info("selectDomain from domainTypeMap not order [{}]", domainSelect);
return domainSelect;
}
} else {
Map.Entry<Integer, SchemaElementCount> maxDomain = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()))
.sorted(ContextHelper.DomainStatComparator).findFirst().orElse(null);
if (maxDomain != null) {
log.info("selectDomain from domainTypeMap order [{}]", maxDomain.getKey());
return maxDomain.getKey();
}
}
return 0;
}
/**
* Checks whether the domain can be switched when the chat context already has a domain.
*
* @return false to keep the context domain; true to allow another domain (possibly still the context domain)
*/
protected static boolean isAllowSwitch(Map<Integer, SemanticQuery> domainQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryContextReq searchCtx, Integer domainId) {
if (!Objects.nonNull(domainId) || domainId <= 0) {
return true;
}
// excluding the context domain, count the element types per domain; if no other domain has more than one type, do not switch
Map<Integer, SchemaElementCount> domainTypeMap = getDomainTypeMap(schemaMap);
log.info("isAllowSwitch domainTypeMap [{}]", domainTypeMap);
long otherDomainTypeNumBigOneCount = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(domainId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherDomainTypeNumBigOneCount >= 1) {
return true;
}
// if the query text only contains a time expression, do not switch
if (searchCtx.getQueryText() != null && searchCtx.getParseInfo().getDateInfo() != null) {
if (searchCtx.getParseInfo().getDateInfo().getText() != null) {
if (searchCtx.getParseInfo().getDateInfo().getText().equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
searchCtx.getParseInfo().getDateInfo());
return false;
}
}
}
// if the context domain is not in schemaMap, switch
if (schemaMap.getMatchedElements(domainId) == null || schemaMap.getMatchedElements(domainId).size() <= 0) {
log.info("domainId not in schemaMap ");
return true;
}
// otherwise do not switch
return false;
}
protected static Map<Integer, SchemaElementCount> getDomainTypeMap(SchemaMapInfo schemaMap) {
Map<Integer, SchemaElementCount> domainCount = new HashMap<>();
for (Map.Entry<Integer, List<SchemaElementMatch>> entry : schemaMap.getDomainElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!domainCount.containsKey(entry.getKey())) {
domainCount.put(entry.getKey(), new SchemaElementCount());
}
SchemaElementCount schemaElementCount = domainCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(schemaElementMatch.getElementType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted(ContextHelper.schemaElementMatchComparatorBySimilarity).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
schemaElementCount.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
schemaElementCount.setCount(schemaElementTypes.size());
}
}
return domainCount;
}
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.service.SemanticQuery;
import com.tencent.supersonic.chat.domain.utils.ContextHelper;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Primary;
@@ -22,121 +23,28 @@ import java.util.Set;
import java.util.stream.Collectors;
@Service("heuristicDomainResolver")
@Primary
public class HeuristicDomainResolver implements DomainResolver {
private static final Logger LOGGER = LoggerFactory.getLogger(HeuristicDomainResolver.class);
@Service("DomainResolver")
@Slf4j
public class HeuristicDomainResolver extends BaseDomainResolver {
@Override
public Integer resolve(Map<Integer, SemanticQuery> domainQueryModes, QueryContextReq searchCtx,
ChatContext chatCtx, SchemaMapInfo schemaMap) {
public Integer selectDomain(Map<Integer, SemanticQuery> domainQueryModes, QueryContextReq searchCtx,
ChatContext chatCtx,
SchemaMapInfo schemaMap) {
// if QueryContext has domainId and in domainQueryModes
if (domainQueryModes.containsKey(searchCtx.getDomainId())) {
LOGGER.info("selectDomain from QueryContext [{}]", searchCtx.getDomainId());
log.info("selectDomain from QueryContext [{}]", searchCtx.getDomainId());
return searchCtx.getDomainId();
}
// if ChatContext has domainId and in domainQueryModes
if (chatCtx.getParseInfo().getDomainId() > 0) {
Integer domainId = Integer.valueOf(chatCtx.getParseInfo().getDomainId().intValue());
if (!isAllowSwitch(domainQueryModes, schemaMap, chatCtx, searchCtx, domainId)) {
LOGGER.info("selectDomain from ChatContext [{}]", domainId);
log.info("selectDomain from ChatContext [{}]", domainId);
return domainId;
}
}
// get the max SchemaElementType number
Map<Integer, SchemaElementCount> domainTypeMap = getDomainTypeMap(schemaMap);
if (domainTypeMap.size() == 1) {
Integer domainSelect = domainTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (domainQueryModes.containsKey(domainSelect)) {
LOGGER.info("selectDomain from domainTypeMap not order [{}]", domainSelect);
return domainSelect;
}
} else {
Map.Entry<Integer, SchemaElementCount> maxDomain = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()))
.sorted(ContextHelper.DomainStatComparator).findFirst().orElse(null);
if (maxDomain != null) {
LOGGER.info("selectDomain from domainTypeMap need order [{}]", maxDomain.getKey());
return maxDomain.getKey();
}
}
// bad case , here will not reach , default 0
LOGGER.error("selectDomain not found ");
// default 0
return 0;
}
@Override
public boolean isDomainSwitch(ChatContext chatCtx, QueryContextReq searchCtx) {
Long contextDomain = chatCtx.getParseInfo().getDomainId();
Long currentDomain = searchCtx.getParseInfo().getDomainId();
boolean noSwitch = currentDomain == null || contextDomain == null || contextDomain.equals(currentDomain);
LOGGER.info("ChatContext isDomainSwitch [{}]", !noSwitch);
return !noSwitch;
}
/**
* to check can switch domain if context exit domain
*
* @return false will use context domain, true will use other domain , maybe include context domain
*/
private static boolean isAllowSwitch(Map<Integer, SemanticQuery> domainQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryContextReq searchCtx, Integer domainId) {
if (!Objects.nonNull(domainId) || domainId <= 0) {
return true;
}
// except content domain, calculate the number of types for each domain, if numbers<=1 will not switch
Map<Integer, SchemaElementCount> domainTypeMap = getDomainTypeMap(schemaMap);
LOGGER.info("isAllowSwitch domainTypeMap [{}]", domainTypeMap);
long otherDomainTypeNumBigOneCount = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(domainId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherDomainTypeNumBigOneCount >= 1) {
return true;
}
// if query text only contain time , will not switch
if (searchCtx.getQueryText() != null && searchCtx.getParseInfo().getDateInfo() != null) {
if (searchCtx.getParseInfo().getDateInfo().getText() != null) {
if (searchCtx.getParseInfo().getDateInfo().getText().equalsIgnoreCase(searchCtx.getQueryText())) {
LOGGER.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
searchCtx.getParseInfo().getDateInfo());
return false;
}
}
}
// if context domain not in schemaMap , will switch
if (schemaMap.getMatchedElements(domainId) == null || schemaMap.getMatchedElements(domainId).size() <= 0) {
LOGGER.info("domainId not in schemaMap ");
return true;
}
// other will not switch
return false;
}
private static Map<Integer, SchemaElementCount> getDomainTypeMap(SchemaMapInfo schemaMap) {
Map<Integer, SchemaElementCount> domainCount = new HashMap<>();
for (Map.Entry<Integer, List<SchemaElementMatch>> entry : schemaMap.getDomainElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!domainCount.containsKey(entry.getKey())) {
domainCount.put(entry.getKey(), new SchemaElementCount());
}
SchemaElementCount schemaElementCount = domainCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(schemaElementMatch.getElementType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted(ContextHelper.schemaElementMatchComparatorBySimilarity).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
schemaElementCount.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
schemaElementCount.setCount(schemaElementTypes.size());
}
}
return domainCount;
}
}

View File

@@ -22,7 +22,7 @@ public class RegexAggregateTypeResolver implements AggregateTypeResolver {
aggregateRegexMap.put(AggregateTypeEnum.MAX, Pattern.compile("(?i)(最大值|最大|max|峰值|最高)"));
aggregateRegexMap.put(AggregateTypeEnum.MIN, Pattern.compile("(?i)(最小值|最小|min|最低)"));
aggregateRegexMap.put(AggregateTypeEnum.SUM, Pattern.compile("(?i)(汇总|总和|sum)"));
aggregateRegexMap.put(AggregateTypeEnum.AVG, Pattern.compile("(?i)(平均值|平均|avg)"));
aggregateRegexMap.put(AggregateTypeEnum.AVG, Pattern.compile("(?i)(平均值|日均|平均|avg)"));
aggregateRegexMap.put(AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)"));
aggregateRegexMap.put(AggregateTypeEnum.DISTINCT, Pattern.compile("(?i)(uv)"));
aggregateRegexMap.put(AggregateTypeEnum.COUNT, Pattern.compile("(?i)(总数|pv)"));

View File

@@ -11,6 +11,7 @@ import org.springframework.stereotype.Service;
public class EntityListFilter extends BaseSemanticQuery {
public static String QUERY_MODE = "ENTITY_LIST_FILTER";
private static Long entityListLimit = 200L;
public EntityListFilter() {
queryModeOption.setAggregation(QueryModeElementOption.unused());
@@ -42,6 +43,7 @@ public class EntityListFilter extends BaseSemanticQuery {
public SemanticParseInfo getContext(ChatContext chatCtx, QueryContextReq queryCtx) {
SemanticParseInfo semanticParseInfo = queryCtx.getParseInfo();
ContextHelper.addIfEmpty(chatCtx.getParseInfo().getDimensionFilters(), semanticParseInfo.getDimensionFilters());
semanticParseInfo.setLimit(entityListLimit);
return semanticParseInfo;
}

View File

@@ -1,13 +1,24 @@
package com.tencent.supersonic.chat.application.query;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.Filter;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.domain.pojo.chat.SchemaElementOption;
import com.tencent.supersonic.chat.domain.utils.ContextHelper;
import com.tencent.supersonic.common.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class MetricCompare extends BaseSemanticQuery {
public static String QUERY_MODE = "METRIC_COMPARE";
@@ -48,7 +59,61 @@ public class MetricCompare extends BaseSemanticQuery {
SemanticParseInfo semanticParseInfo = queryCtx.getParseInfo();
ContextHelper.updateTimeIfEmpty(chatCtx.getParseInfo(), semanticParseInfo);
ContextHelper.addIfEmpty(chatCtx.getParseInfo().getMetrics(), semanticParseInfo.getMetrics());
ContextHelper.appendList(chatCtx.getParseInfo().getDimensionFilters(), semanticParseInfo.getDimensionFilters());
mergeAppend(chatCtx.getParseInfo().getDimensionFilters(), semanticParseInfo.getDimensionFilters());
addCompareDimension(semanticParseInfo);
return semanticParseInfo;
}
private void addCompareDimension(SemanticParseInfo semanticParseInfo) {
if (!semanticParseInfo.getDimensionFilters().isEmpty()) {
Set<String> dimensions = semanticParseInfo.getDimensions().stream().map(d -> d.getBizName()).collect(
Collectors.toSet());
log.info("addCompareDimension before [{}]", dimensions);
semanticParseInfo.getDimensionFilters().stream().filter(d -> d.getOperator().equals(FilterOperatorEnum.IN))
.forEach(
d -> {
if (!dimensions.contains(d.getBizName())) {
SchemaItem schemaItem = new SchemaItem();
schemaItem.setBizName(d.getBizName());
schemaItem.setId(d.getElementID());
semanticParseInfo.getDimensions().add(schemaItem);
dimensions.add(d.getBizName());
}
}
);
log.info("addCompareDimension after [{}]", dimensions);
}
}
private void mergeAppend(Set<Filter> from, Set<Filter> to) {
if (!from.isEmpty()) {
for (Filter filter : from) {
if (FilterOperatorEnum.EQUALS.equals(filter.getOperator()) || FilterOperatorEnum.IN.equals(
filter.getOperator())) {
Optional<Filter> toAdd = to.stream()
.filter(t -> t.getBizName().equalsIgnoreCase(filter.getBizName())).findFirst();
if (toAdd.isPresent()) {
if (FilterOperatorEnum.EQUALS.equals(toAdd.get().getOperator()) || FilterOperatorEnum.IN.equals(
toAdd.get().getOperator())) {
List<Object> vals = new ArrayList<>();
if (toAdd.get().getOperator().equals(FilterOperatorEnum.IN)) {
vals.addAll((List<Object>) (toAdd.get().getValue()));
} else {
vals.add(toAdd.get().getValue());
}
if (filter.getOperator().equals(FilterOperatorEnum.IN)) {
vals.addAll((List<Object>) (filter.getValue()));
} else {
vals.add(filter.getValue());
}
toAdd.get().setValue(vals);
toAdd.get().setOperator(FilterOperatorEnum.IN);
continue;
}
}
}
to.add(filter);
}
}
}
}
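
Note: mergeAppend above folds a chat-context filter into an existing filter on the same bizName instead of appending a duplicate condition: EQUALS/IN values are combined and the result becomes a single IN filter. A worked example of the outcome (the filter shapes are written informally; only the merge result reflects the code above):

// existing filter in the current parse:      city IN ["shanghai"]
// filter carried over from the chat context: city EQUALS "beijing"
// after mergeAppend(contextFilters, currentFilters):
//   one filter remains on "city":            city IN ["shanghai", "beijing"]
// a context filter whose bizName is not present yet is simply added as-is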

View File

@@ -2,21 +2,25 @@ package com.tencent.supersonic.chat.domain.pojo.chat;
import com.tencent.supersonic.chat.api.pojo.Filter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.SchemaItem;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
@Data
public class QueryData {
Long domainId = 0L;
List<SchemaItem> metrics = new ArrayList<>();
List<SchemaItem> dimensions = new ArrayList<>();
List<Filter> filters = new ArrayList<>();
private List<Order> orders = new ArrayList<>();
Set<SchemaItem> metrics = new HashSet<>();
Set<SchemaItem> dimensions = new HashSet<>();
Set<Filter> dimensionFilters = new HashSet<>();
Set<Filter> metricFilters = new HashSet<>();
private Set<Order> orders = new HashSet<>();
private DateConf dateInfo;
private Long limit;
private Boolean nativeQuery = false;

View File

@@ -4,6 +4,7 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.response.QueryResultResp;
import com.tencent.supersonic.chat.domain.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.domain.pojo.chat.ChatQueryVO;
import com.tencent.supersonic.chat.domain.pojo.chat.PageQueryInfoReq;
@@ -13,4 +14,7 @@ public interface ChatQueryRepository {
void createChatQuery(QueryResultResp queryResponse, QueryContextReq queryContext, ChatContext chatCtx);
ChatQueryDO getLastChatQuery(long chatId);
int updateChatQuery(ChatQueryDO chatQueryDO);
}

View File

@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.request.QueryContextReq;
import com.tencent.supersonic.chat.api.response.QueryResultResp;
import com.tencent.supersonic.chat.domain.dataobject.ChatDO;
import com.tencent.supersonic.chat.domain.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.domain.pojo.chat.ChatQueryVO;
import com.tencent.supersonic.chat.domain.pojo.chat.PageQueryInfoReq;
import java.util.List;
@@ -40,4 +41,8 @@ public interface ChatService {
PageInfo<ChatQueryVO> queryInfo(PageQueryInfoReq pageQueryInfoCommend, long chatId);
public void addQuery(QueryResultResp queryResponse, QueryContextReq queryContext, ChatContext chatCtx);
public ChatQueryDO getLastQuery(long chatId);
public int updateQuery(ChatQueryDO chatQueryDO);
}

View File

@@ -10,6 +10,7 @@ import com.tencent.supersonic.chat.api.service.SemanticQuery;
import com.tencent.supersonic.semantic.api.core.response.DimSchemaResp;
import com.tencent.supersonic.chat.domain.pojo.config.ChatConfigRichInfo;
import com.tencent.supersonic.common.pojo.SchemaItem;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
@@ -71,7 +72,7 @@ public class ContextHelper {
* @param from
* @param to
*/
public static void addIfEmpty(List from, List to) {
public static void addIfEmpty(Set from, Set to) {
if (to.isEmpty() && !from.isEmpty()) {
to.addAll(from);
}
@@ -82,7 +83,7 @@ public class ContextHelper {
* @param from
* @param to
*/
public static void appendList(List from, List to) {
public static void appendList(Set from, Set to) {
if (!from.isEmpty()) {
to.addAll(from);
}
@@ -94,7 +95,7 @@ public class ContextHelper {
* @param from
* @param to
*/
public static void updateList(List from, List to) {
public static void updateList(Set from, Set to) {
if (!from.isEmpty()) {
to.clear();
to.addAll(from);

View File

@@ -21,9 +21,14 @@ import com.tencent.supersonic.semantic.api.core.response.DomainSchemaResp;
import com.tencent.supersonic.semantic.api.core.response.MetricSchemaResp;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
public class SchemaInfoConverter {
@@ -40,6 +45,7 @@ public class SchemaInfoConverter {
queryStructCmd.setDateInfo(parseInfo.getDateInfo());
List<Filter> dimensionFilters = parseInfo.getDimensionFilters().stream()
.filter(chatFilter -> Strings.isNotEmpty(chatFilter.getBizName()))
.map(chatFilter -> new Filter(chatFilter.getBizName(), chatFilter.getOperator(), chatFilter.getValue()))
.collect(Collectors.toList());
queryStructCmd.setDimensionFilters(dimensionFilters);
@@ -54,12 +60,13 @@ public class SchemaInfoConverter {
.collect(Collectors.toList());
queryStructCmd.setGroups(dimensions);
queryStructCmd.setLimit(parseInfo.getLimit());
queryStructCmd.setOrders(getOrder(parseInfo.getOrders(), parseInfo.getAggType(), parseInfo.getMetrics()));
Set<Order> order = getOrder(parseInfo.getOrders(), parseInfo.getAggType(), parseInfo.getMetrics());
queryStructCmd.setOrders(new ArrayList<>(order));
queryStructCmd.setAggregators(getAggregatorByMetric(parseInfo.getMetrics(), parseInfo.getAggType()));
return queryStructCmd;
}
private static List<Aggregator> getAggregatorByMetric(List<SchemaItem> metrics, AggregateTypeEnum aggregateType) {
private static List<Aggregator> getAggregatorByMetric(Set<SchemaItem> metrics, AggregateTypeEnum aggregateType) {
List<Aggregator> aggregators = new ArrayList<>();
String agg = (aggregateType == null || aggregateType.equals(AggregateTypeEnum.NONE)) ? ""
: aggregateType.name();
@@ -157,11 +164,11 @@ public class SchemaInfoConverter {
return result;
}
public static List<Order> getOrder(List<Order> parseOrder, AggregateTypeEnum aggregator, List<SchemaItem> metrics) {
public static Set<Order> getOrder(Set<Order> parseOrder, AggregateTypeEnum aggregator, Set<SchemaItem> metrics) {
if (!CollectionUtils.isEmpty(parseOrder)) {
return parseOrder;
}
List<Order> orders = new ArrayList<>();
Set<Order> orders = new LinkedHashSet<>();
if (CollectionUtils.isEmpty(metrics)) {
return orders;
}

View File

@@ -14,11 +14,13 @@ import com.tencent.supersonic.chat.domain.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.infrastructure.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.common.util.json.JsonUtil;
import com.tencent.supersonic.common.util.mybatis.PageUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils;
@Repository
@Primary
@@ -69,4 +71,24 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
Long queryId = Long.valueOf(chatQueryDOMapper.insert(chatQueryDO));
queryResponse.setQueryId(queryId);
}
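// Returns the most recent query of the given chat (highest question_id), or null if the chat has no queries yet.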
@Override
public ChatQueryDO getLastChatQuery(long chatId) {
ChatQueryDOExample example = new ChatQueryDOExample();
example.setOrderByClause("question_id desc");
example.setLimitEnd(1);
example.setLimitStart(0);
Criteria criteria = example.createCriteria();
criteria.andChatIdEqualTo(chatId);
List<ChatQueryDO> chatQueryDOS = chatQueryDOMapper.selectByExampleWithBLOBs(example);
if (!CollectionUtils.isEmpty(chatQueryDOS)) {
return chatQueryDOS.get(0);
}
return null;
}
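// Persists changes to an existing chat query, including its BLOB columns.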
@Override
public int updateChatQuery(ChatQueryDO chatQueryDO) {
return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO);
}
}

View File

@@ -11,11 +11,11 @@ class HanlpSchemaMapperTest extends ContextTest {
@Test
void map() {
QueryContextReq searchCtx = new QueryContextReq();
searchCtx.setChatId(1);
searchCtx.setDomainId(2);
searchCtx.setQueryText("supersonic按部门访问次数");
QueryContextReq queryContext = new QueryContextReq();
queryContext.setChatId(1);
queryContext.setDomainId(2);
queryContext.setQueryText("supersonic按部门访问次数");
HanlpSchemaMapper hanlpSchemaMapper = new HanlpSchemaMapper();
hanlpSchemaMapper.map(searchCtx);
hanlpSchemaMapper.map(queryContext);
}
}

View File

@@ -13,14 +13,14 @@ class TimeSemanticParserTest {
void parse() {
TimeSemanticParser timeSemanticParser = new TimeSemanticParser();
QueryContextReq searchCtx = new QueryContextReq();
QueryContextReq queryContext = new QueryContextReq();
ChatContext chatCtx = new ChatContext();
SchemaMapInfo schemaMap = new SchemaMapInfo();
searchCtx.setQueryText("supersonic最近30天访问次数");
queryContext.setQueryText("supersonic最近30天访问次数");
boolean parse = timeSemanticParser.parse(searchCtx, chatCtx);
boolean parse = timeSemanticParser.parse(queryContext, chatCtx);
DateConf dateInfo = searchCtx.getParseInfo().getDateInfo();
DateConf dateInfo = queryContext.getParseInfo().getDateInfo();
}

View File

@@ -8,7 +8,9 @@ import com.tencent.supersonic.common.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.SchemaItem;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import lombok.Data;
public class SemanticParseObjectHelper {
@@ -29,9 +31,9 @@ public class SemanticParseObjectHelper {
private static SemanticParseInfo getSemanticParseInfo(SemanticParseJson semanticParseJson) {
Long domain = semanticParseJson.getDomain();
List<SchemaItem> dimensionList = new ArrayList<>();
List<SchemaItem> metricList = new ArrayList<>();
List<Filter> chatFilters = new ArrayList<>();
Set<SchemaItem> dimensionList = new LinkedHashSet<>();
Set<SchemaItem> metricList = new LinkedHashSet<>();
Set<Filter> chatFilters = new LinkedHashSet<>();
if (semanticParseJson.getFilter() != null && semanticParseJson.getFilter().size() > 0) {
for (List<String> filter : semanticParseJson.getFilter()) {