mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(chat) Support semantic understanding, optimize the overall code of the mapper. (#321)
This commit is contained in:
@@ -42,4 +42,18 @@ public class OptimizationConfig {
|
|||||||
@Value("${user.s2ql.switch:false}")
|
@Value("${user.s2ql.switch:false}")
|
||||||
private boolean useS2qlSwitch;
|
private boolean useS2qlSwitch;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.word.min:2}")
|
||||||
|
private int embeddingMapperWordMin;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.number:10}")
|
||||||
|
private int embeddingMapperNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.sum.number:10}")
|
||||||
|
private int embeddingMapperSumNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.distance.threshold:0.3}")
|
||||||
|
private Double embeddingMapperDistanceThreshold;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,26 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* base Mapper
|
* base Mapper
|
||||||
@@ -19,12 +32,17 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
public void map(QueryContext queryContext) {
|
public void map(QueryContext queryContext) {
|
||||||
|
|
||||||
String simpleName = this.getClass().getSimpleName();
|
String simpleName = this.getClass().getSimpleName();
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo());
|
||||||
|
|
||||||
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo());
|
try {
|
||||||
|
work(queryContext);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("work error", e);
|
||||||
|
}
|
||||||
|
|
||||||
work(queryContext);
|
long cost = System.currentTimeMillis() - startTime;
|
||||||
|
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo());
|
||||||
log.debug("after {},mapInfo:{}", simpleName, queryContext.getMapInfo());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void work(QueryContext queryContext);
|
public abstract void work(QueryContext queryContext);
|
||||||
@@ -38,4 +56,63 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
schemaElementMatches.add(schemaElementMatch);
|
schemaElementMatches.add(schemaElementMatch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Set<Long> getModelIds(QueryContext queryContext) {
|
||||||
|
return ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest());
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
logTerms(terms);
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||||
|
terms = terms.stream().filter(term -> {
|
||||||
|
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||||
|
if (Objects.nonNull(modelId)) {
|
||||||
|
return detectModelIds.contains(modelId);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}).collect(Collectors.toList());
|
||||||
|
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||||
|
logTerms(terms);
|
||||||
|
}
|
||||||
|
return terms;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void logTerms(List<Term> terms) {
|
||||||
|
if (CollectionUtils.isEmpty(terms)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (Term term : terms) {
|
||||||
|
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
|
||||||
|
SchemaElement element = new SchemaElement();
|
||||||
|
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
||||||
|
if (Objects.isNull(modelSchema)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||||
|
if (Objects.isNull(elementDb)) {
|
||||||
|
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
BeanUtils.copyProperties(elementDb, element);
|
||||||
|
element.setAlias(getAlias(elementDb));
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getAlias(SchemaElement element) {
|
||||||
|
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
|
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
|
||||||
|
element.getName())) {
|
||||||
|
return element.getAlias().stream()
|
||||||
|
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* match strategy implement
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
String text = queryReq.getQueryText();
|
||||||
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||||
|
|
||||||
|
List<T> detects = detect(queryReq, terms, detectModelIds);
|
||||||
|
Map<MatchText, List<T>> result = new HashMap<>();
|
||||||
|
|
||||||
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<T> detect(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||||
|
String text = queryReq.getQueryText();
|
||||||
|
Set<T> results = new HashSet<>();
|
||||||
|
|
||||||
|
for (Integer index = 0; index <= text.length() - 1; ) {
|
||||||
|
|
||||||
|
for (Integer i = index; i <= text.length(); ) {
|
||||||
|
int offset = mapperHelper.getStepOffset(terms, index);
|
||||||
|
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
||||||
|
if (i <= text.length()) {
|
||||||
|
detectByStep(queryReq, results, detectModelIds, index, i, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
|
}
|
||||||
|
return new ArrayList<>(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
||||||
|
return terms.stream().sorted(Comparator.comparing(Term::length))
|
||||||
|
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||||
|
if (CollectionUtils.isEmpty(oneRoundResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (T oneRoundResult : oneRoundResults) {
|
||||||
|
if (existResults.contains(oneRoundResult)) {
|
||||||
|
boolean isDeleted = existResults.removeIf(
|
||||||
|
existResult -> {
|
||||||
|
boolean delete = needDelete(oneRoundResult, existResult);
|
||||||
|
if (delete) {
|
||||||
|
log.info("deleted existResult:{}", existResult);
|
||||||
|
}
|
||||||
|
return delete;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (isDeleted) {
|
||||||
|
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
|
||||||
|
existResults.add(oneRoundResult);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
existResults.add(oneRoundResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<T> getMatches(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
Map<MatchText, List<T>> matchResult = match(queryReq, terms, detectModelIds);
|
||||||
|
List<T> matches = new ArrayList<>();
|
||||||
|
if (Objects.isNull(matchResult)) {
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
Optional<List<T>> first = matchResult.entrySet().stream()
|
||||||
|
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||||
|
.map(entry -> entry.getValue()).findFirst();
|
||||||
|
|
||||||
|
if (first.isPresent()) {
|
||||||
|
matches = first.get();
|
||||||
|
}
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract boolean needDelete(T oneRoundResult, T existResult);
|
||||||
|
|
||||||
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
|
public abstract void detectByStep(QueryReq queryReq, Set<T> results, Set<Long> detectModelIds, Integer startIndex,
|
||||||
|
Integer index, int offset);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,7 +1,19 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||||
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* A mapper that is capable of semantic understanding of text.
|
* A mapper that is capable of semantic understanding of text.
|
||||||
@@ -13,10 +25,44 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
public void work(QueryContext queryContext) {
|
public void work(QueryContext queryContext) {
|
||||||
//1. query from embedding by queryText
|
//1. query from embedding by queryText
|
||||||
|
|
||||||
|
String queryText = queryContext.getRequest().getQueryText();
|
||||||
|
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||||
|
|
||||||
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
|
|
||||||
|
Set<Long> detectModelIds = getModelIds(queryContext);
|
||||||
|
|
||||||
|
terms = filterByModelIds(terms, detectModelIds);
|
||||||
|
|
||||||
|
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
|
||||||
|
|
||||||
|
HanlpHelper.transLetterOriginal(matchResults);
|
||||||
|
|
||||||
//2. build SchemaElementMatch by info
|
//2. build SchemaElementMatch by info
|
||||||
|
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||||
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
|
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||||
|
|
||||||
//3. add to mapInfo
|
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||||
|
SchemaElement.class);
|
||||||
|
|
||||||
|
String modelIdStr = matchResult.getMetadata().get("modelId");
|
||||||
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
long modelId = Long.parseLong(modelIdStr);
|
||||||
|
|
||||||
|
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
|
||||||
|
|
||||||
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
|
.element(schemaElement)
|
||||||
|
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||||
|
.word(matchResult.getName())
|
||||||
|
.similarity(mapperHelper.getSimilarity(matchResult.getName(), matchResult.getDetectWord()))
|
||||||
|
.detectWord(matchResult.getDetectWord())
|
||||||
|
.build();
|
||||||
|
//3. add to mapInfo
|
||||||
|
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,116 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* match strategy implement
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
@Autowired
|
||||||
|
private EmbeddingUtils embeddingUtils;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDistance() > oneRoundResult.getDistance();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey(EmbeddingResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void detectByStep(QueryReq queryReq, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
String detectSegment = queryReq.getQueryText().substring(startIndex, index);
|
||||||
|
// step1. build query params
|
||||||
|
if (StringUtils.isBlank(detectSegment)
|
||||||
|
|| detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMin()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
|
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||||
|
Map<String, String> filterCondition = null;
|
||||||
|
|
||||||
|
// if only one modelId, add to filterCondition
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
||||||
|
filterCondition = new HashMap<>();
|
||||||
|
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
|
.queryTextsList(Collections.singletonList(detectSegment))
|
||||||
|
.filterCondition(filterCondition)
|
||||||
|
.queryEmbeddings(null)
|
||||||
|
.build();
|
||||||
|
// step2. retrieveQuery by detectSegment
|
||||||
|
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
||||||
|
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// step3. build EmbeddingResults. filter by modelId
|
||||||
|
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||||
|
.map(retrieveQueryResult -> {
|
||||||
|
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||||
|
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||||
|
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||||
|
retrievals.removeIf(retrieval -> {
|
||||||
|
String modelIdStr = retrieval.getMetadata().get("modelId");
|
||||||
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return retrieveQueryResult;
|
||||||
|
})
|
||||||
|
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||||
|
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
||||||
|
.map(retrieval -> {
|
||||||
|
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||||
|
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||||
|
embeddingResult.setDetectWord(detectSegment);
|
||||||
|
embeddingResult.setName(retrieval.getQuery());
|
||||||
|
return embeddingResult;
|
||||||
|
}))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
// step4. select mapResul in one round
|
||||||
|
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber();
|
||||||
|
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||||
|
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||||
|
.limit(roundNumber)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
selectResultInOneRound(existResults, oneRoundResults);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,28 +1,22 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.BeanUtils;
|
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* A mapper capable of prefix and suffix similarity parsing for
|
* A mapper capable of prefix and suffix similarity parsing for
|
||||||
@@ -37,45 +31,23 @@ public class HanlpDictMapper extends BaseMapper {
|
|||||||
String queryText = queryContext.getRequest().getQueryText();
|
String queryText = queryContext.getRequest().getQueryText();
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||||
|
|
||||||
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
|
HanlpMatchStrategy matchStrategy = ContextUtils.getBean(HanlpMatchStrategy.class);
|
||||||
|
|
||||||
Set<Long> detectModelIds = ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest());
|
Set<Long> detectModelIds = getModelIds(queryContext);
|
||||||
|
|
||||||
terms = filterByModelIds(terms, detectModelIds);
|
terms = filterByModelIds(terms, detectModelIds);
|
||||||
|
|
||||||
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
|
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
|
||||||
detectModelIds);
|
|
||||||
|
|
||||||
List<MapResult> matches = getMatches(matchResult);
|
|
||||||
|
|
||||||
HanlpHelper.transLetterOriginal(matches);
|
HanlpHelper.transLetterOriginal(matches);
|
||||||
|
|
||||||
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
|
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
|
||||||
for (Term term : terms) {
|
|
||||||
log.info("before word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
|
||||||
}
|
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
|
||||||
terms = terms.stream().filter(term -> {
|
|
||||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
|
||||||
if (Objects.nonNull(modelId)) {
|
|
||||||
return detectModelIds.contains(modelId);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
for (Term term : terms) {
|
|
||||||
log.info("after filter word:{},nature:{},frequency:{}", term.word, term.nature.toString(),
|
|
||||||
term.getFrequency());
|
|
||||||
}
|
|
||||||
return terms;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
private void convertTermsToSchemaMapInfo(List<HanlpMapResult> hanlpMapResults, SchemaMapInfo schemaMap,
|
||||||
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
|
List<Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,8 +55,8 @@ public class HanlpDictMapper extends BaseMapper {
|
|||||||
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
||||||
|
|
||||||
for (MapResult mapResult : mapResults) {
|
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
|
||||||
for (String nature : mapResult.getNatures()) {
|
for (String nature : hanlpMapResult.getNatures()) {
|
||||||
Long modelId = NatureHelper.getModelId(nature);
|
Long modelId = NatureHelper.getModelId(nature);
|
||||||
if (Objects.isNull(modelId)) {
|
if (Objects.isNull(modelId)) {
|
||||||
continue;
|
continue;
|
||||||
@@ -93,33 +65,21 @@ public class HanlpDictMapper extends BaseMapper {
|
|||||||
if (Objects.isNull(elementType)) {
|
if (Objects.isNull(elementType)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
|
||||||
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
|
||||||
if (Objects.isNull(modelSchema)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Long elementID = NatureHelper.getElementID(nature);
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
|
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
|
||||||
|
if (element == null) {
|
||||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
|
||||||
if (Objects.isNull(elementDb)) {
|
|
||||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElement element = new SchemaElement();
|
|
||||||
BeanUtils.copyProperties(elementDb, element);
|
|
||||||
element.setAlias(getAlias(elementDb));
|
|
||||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||||
element.setName(mapResult.getName());
|
element.setName(hanlpMapResult.getName());
|
||||||
}
|
}
|
||||||
|
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
.frequency(frequency)
|
.frequency(frequency)
|
||||||
.word(mapResult.getName())
|
.word(hanlpMapResult.getName())
|
||||||
.similarity(mapResult.getSimilarity())
|
.similarity(hanlpMapResult.getSimilarity())
|
||||||
.detectWord(mapResult.getDetectWord())
|
.detectWord(hanlpMapResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
||||||
@@ -127,30 +87,5 @@ public class HanlpDictMapper extends BaseMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<MapResult> getMatches(Map<MatchText, List<MapResult>> matchResult) {
|
|
||||||
List<MapResult> matches = new ArrayList<>();
|
|
||||||
if (Objects.isNull(matchResult)) {
|
|
||||||
return matches;
|
|
||||||
}
|
|
||||||
Optional<List<MapResult>> first = matchResult.entrySet().stream()
|
|
||||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
|
||||||
.map(entry -> entry.getValue()).findFirst();
|
|
||||||
|
|
||||||
if (first.isPresent()) {
|
|
||||||
matches = first.get();
|
|
||||||
}
|
|
||||||
return matches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getAlias(SchemaElement element) {
|
|
||||||
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
|
||||||
return element.getAlias();
|
|
||||||
}
|
|
||||||
if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) {
|
|
||||||
return element.getAlias().stream()
|
|
||||||
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
return element.getAlias();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,119 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* match strategy implement
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
String text = queryReq.getQueryText();
|
||||||
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||||
|
|
||||||
|
List<HanlpMapResult> detects = detect(queryReq, terms, detectModelIds);
|
||||||
|
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||||
|
|
||||||
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
String text = queryReq.getQueryText();
|
||||||
|
Integer agentId = queryReq.getAgentId();
|
||||||
|
String detectSegment = text.substring(startIndex, index);
|
||||||
|
|
||||||
|
// step1. pre search
|
||||||
|
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||||
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||||
|
agentId,
|
||||||
|
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
// step2. suffix search
|
||||||
|
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
||||||
|
oneDetectionMaxSize,
|
||||||
|
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// step3. merge pre/suffix result
|
||||||
|
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||||
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
// step4. filter by similarity
|
||||||
|
hanlpMapResults = hanlpMapResults.stream()
|
||||||
|
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||||
|
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
||||||
|
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||||
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
log.info("after isSimilarity parseResults:{}", hanlpMapResults);
|
||||||
|
|
||||||
|
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
|
||||||
|
parseResult.setOffset(offset);
|
||||||
|
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
||||||
|
return parseResult;
|
||||||
|
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
// step5. take only one dimension or 10 metric/dimension value per rond.
|
||||||
|
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
|
||||||
|
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
.stream()
|
||||||
|
.limit(1)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
|
||||||
|
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||||
|
oneRoundResults = dimensionMetrics;
|
||||||
|
}
|
||||||
|
// step6. select mapResul in one round
|
||||||
|
selectResultInOneRound(existResults, oneRoundResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getMapKey(HanlpMapResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -39,10 +41,14 @@ public class MapperHelper {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Integer getStepOffset(List<Integer> termList, Integer index) {
|
|
||||||
|
public Integer getStepOffset(List<Term> termList, Integer index) {
|
||||||
|
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||||
|
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||||
|
|
||||||
for (int j = 0; j < termList.size() - 1; j++) {
|
for (int j = 0; j < termList.size() - 1; j++) {
|
||||||
if (termList.get(j) <= index && termList.get(j + 1) > index) {
|
if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
|
||||||
return termList.get(j);
|
return offsetList.get(j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return index;
|
return index;
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.mapper;
|
|||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -10,8 +9,8 @@ import java.util.Set;
|
|||||||
/**
|
/**
|
||||||
* match strategy
|
* match strategy
|
||||||
*/
|
*/
|
||||||
public interface MatchStrategy {
|
public interface MatchStrategy<T> {
|
||||||
|
|
||||||
Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
|
Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.LinkedHashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.apache.commons.compress.utils.Lists;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* match strategy implement
|
|
||||||
*/
|
|
||||||
@Service
|
|
||||||
@Slf4j
|
|
||||||
public class QueryMatchStrategy implements MatchStrategy {
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private MapperHelper mapperHelper;
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private OptimizationConfig optimizationConfig;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
|
|
||||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
|
||||||
|
|
||||||
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
|
|
||||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
|
||||||
|
|
||||||
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms,
|
|
||||||
regOffsetToLength, offsetList, detectModelIds);
|
|
||||||
|
|
||||||
List<MapResult> detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds);
|
|
||||||
Map<MatchText, List<MapResult>> result = new HashMap<>();
|
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<MapResult> detect(QueryReq queryReq, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
|
|
||||||
Set<Long> detectModelIds) {
|
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
List<MapResult> results = Lists.newArrayList();
|
|
||||||
|
|
||||||
for (Integer index = 0; index <= text.length() - 1; ) {
|
|
||||||
|
|
||||||
Set<MapResult> mapResultRowSet = new LinkedHashSet();
|
|
||||||
|
|
||||||
for (Integer i = index; i <= text.length(); ) {
|
|
||||||
int offset = mapperHelper.getStepOffset(offsetList, index);
|
|
||||||
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
|
||||||
if (i <= text.length()) {
|
|
||||||
List<MapResult> mapResults = detectByStep(queryReq, detectModelIds, index, i, offset);
|
|
||||||
selectMapResultInOneRound(mapResultRowSet, mapResults);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
|
||||||
results.addAll(mapResultRowSet);
|
|
||||||
}
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void selectMapResultInOneRound(Set<MapResult> mapResultRowSet, List<MapResult> mapResults) {
|
|
||||||
for (MapResult mapResult : mapResults) {
|
|
||||||
if (mapResultRowSet.contains(mapResult)) {
|
|
||||||
boolean isDeleted = mapResultRowSet.removeIf(
|
|
||||||
entry -> {
|
|
||||||
boolean deleted = getMapKey(mapResult).equals(getMapKey(entry))
|
|
||||||
&& entry.getDetectWord().length() < mapResult.getDetectWord().length();
|
|
||||||
if (deleted) {
|
|
||||||
log.info("deleted entry:{}", entry);
|
|
||||||
}
|
|
||||||
return deleted;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
if (isDeleted) {
|
|
||||||
log.info("deleted, add mapResult:{}", mapResult);
|
|
||||||
mapResultRowSet.add(mapResult);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
mapResultRowSet.add(mapResult);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getMapKey(MapResult a) {
|
|
||||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<MapResult> detectByStep(QueryReq queryReq, Set<Long> detectModelIds, Integer index, Integer i,
|
|
||||||
int offset) {
|
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
Integer agentId = queryReq.getAgentId();
|
|
||||||
String detectSegment = text.substring(index, i);
|
|
||||||
|
|
||||||
// step1. pre search
|
|
||||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
|
||||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
|
|
||||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
// step2. suffix search
|
|
||||||
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize,
|
|
||||||
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
mapResults.addAll(suffixMapResults);
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
// step3. merge pre/suffix result
|
|
||||||
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
// step4. filter by similarity
|
|
||||||
mapResults = mapResults.stream()
|
|
||||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
|
||||||
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
|
||||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
log.info("after isSimilarity parseResults:{}", mapResults);
|
|
||||||
|
|
||||||
mapResults = mapResults.stream().map(parseResult -> {
|
|
||||||
parseResult.setOffset(offset);
|
|
||||||
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
|
||||||
return parseResult;
|
|
||||||
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
// step5. take only one dimension or 10 metric/dimension value per rond.
|
|
||||||
List<MapResult> dimensionMetrics = mapResults.stream()
|
|
||||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
|
||||||
.collect(Collectors.toList())
|
|
||||||
.stream()
|
|
||||||
.limit(1)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
|
||||||
return dimensionMetrics;
|
|
||||||
} else {
|
|
||||||
return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -4,7 +4,7 @@ import com.google.common.collect.Lists;
|
|||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -20,17 +20,15 @@ import org.springframework.stereotype.Service;
|
|||||||
* match strategy implement
|
* match strategy implement
|
||||||
*/
|
*/
|
||||||
@Service
|
@Service
|
||||||
public class SearchMatchStrategy implements MatchStrategy {
|
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||||
|
|
||||||
private static final int SEARCH_SIZE = 3;
|
private static final int SEARCH_SIZE = 3;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> originals, Set<Long> detectModelIds) {
|
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> originals,
|
||||||
|
Set<Long> detectModelIds) {
|
||||||
String text = queryReq.getQueryText();
|
String text = queryReq.getQueryText();
|
||||||
Map<Integer, Integer> regOffsetToLength = originals.stream()
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||||
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
|
|
||||||
.collect(Collectors.toMap(Term::getOffset, value -> value.word.length(),
|
|
||||||
(value1, value2) -> value2));
|
|
||||||
|
|
||||||
List<Integer> detectIndexList = Lists.newArrayList();
|
List<Integer> detectIndexList = Lists.newArrayList();
|
||||||
|
|
||||||
@@ -46,19 +44,19 @@ public class SearchMatchStrategy implements MatchStrategy {
|
|||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Map<MatchText, List<MapResult>> regTextMap = new ConcurrentHashMap<>();
|
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
|
||||||
detectIndexList.stream().parallel().forEach(detectIndex -> {
|
detectIndexList.stream().parallel().forEach(detectIndex -> {
|
||||||
String regText = text.substring(0, detectIndex);
|
String regText = text.substring(0, detectIndex);
|
||||||
String detectSegment = text.substring(detectIndex);
|
String detectSegment = text.substring(detectIndex);
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||||
List<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
|
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
||||||
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||||
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE,
|
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
||||||
queryReq.getAgentId(), detectModelIds);
|
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||||
mapResults.addAll(suffixMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
// remove entity name where search
|
// remove entity name where search
|
||||||
mapResults = mapResults.stream().filter(entry -> {
|
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||||
List<String> natures = entry.getNatures().stream()
|
List<String> natures = entry.getNatures().stream()
|
||||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
@@ -71,10 +69,27 @@ public class SearchMatchStrategy implements MatchStrategy {
|
|||||||
.regText(regText)
|
.regText(regText)
|
||||||
.detectSegment(detectSegment)
|
.detectSegment(detectSegment)
|
||||||
.build();
|
.build();
|
||||||
regTextMap.put(matchText, mapResults);
|
regTextMap.put(matchText, hanlpMapResults);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
return regTextMap;
|
return regTextMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey(HanlpMapResult a) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex,
|
||||||
|
Integer i, int offset) {
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder {
|
|||||||
metric.setOrder(metricOrder++);
|
metric.setOrder(metricOrder++);
|
||||||
}
|
}
|
||||||
for (Retrieval retrieval : retrievals) {
|
for (Retrieval retrieval : retrievals) {
|
||||||
if (!metricIds.contains(retrieval.getId())) {
|
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
|
||||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
||||||
SchemaElement.class);
|
SchemaElement.class);
|
||||||
if (retrieval.getMetadata().containsKey("modelId")) {
|
if (retrieval.getMetadata().containsKey("modelId")) {
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
|||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
|
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
@@ -628,10 +628,10 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
|
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
//search from prefixSearch
|
//search from prefixSearch
|
||||||
List<MapResult> mapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(),
|
List<HanlpMapResult> hanlpMapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(),
|
||||||
2000, dimensionValueReq.getAgentId(), detectModelIds);
|
2000, dimensionValueReq.getAgentId(), detectModelIds);
|
||||||
HanlpHelper.transLetterOriginal(mapResultList);
|
HanlpHelper.transLetterOriginal(hanlpMapResultList);
|
||||||
return mapResultList.stream()
|
return hanlpMapResultList.stream()
|
||||||
.filter(o -> {
|
.filter(o -> {
|
||||||
for (String nature : o.getNatures()) {
|
for (String nature : o.getNatures()) {
|
||||||
Long elementID = NatureHelper.getElementID(nature);
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import com.tencent.supersonic.chat.service.ChatService;
|
|||||||
import com.tencent.supersonic.chat.service.SearchService;
|
import com.tencent.supersonic.chat.service.SearchService;
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
@@ -94,11 +94,12 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq);
|
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq);
|
||||||
|
|
||||||
Map<MatchText, List<MapResult>> regTextMap = searchMatchStrategy.match(queryReq, originals, detectModelIds);
|
Map<MatchText, List<HanlpMapResult>> regTextMap =
|
||||||
|
searchMatchStrategy.match(queryReq, originals, detectModelIds);
|
||||||
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
||||||
|
|
||||||
// 4.get the most matching data
|
// 4.get the most matching data
|
||||||
Optional<Entry<MatchText, List<MapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
|
Optional<Entry<MatchText, List<HanlpMapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
|
||||||
.stream()
|
.stream()
|
||||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||||
.reduce((entry1, entry2) ->
|
.reduce((entry1, entry2) ->
|
||||||
@@ -109,7 +110,7 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
if (!mostSimilarSearchResult.isPresent()) {
|
if (!mostSimilarSearchResult.isPresent()) {
|
||||||
return Lists.newArrayList();
|
return Lists.newArrayList();
|
||||||
}
|
}
|
||||||
Map.Entry<MatchText, List<MapResult>> searchTextEntry = mostSimilarSearchResult.get();
|
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get();
|
||||||
log.info("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq);
|
log.info("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq);
|
||||||
|
|
||||||
Set<SearchResult> searchResults = new LinkedHashSet();
|
Set<SearchResult> searchResults = new LinkedHashSet();
|
||||||
@@ -275,9 +276,9 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
* @param recommendTextListEntry
|
* @param recommendTextListEntry
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<MapResult>> recommendTextListEntry,
|
private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<HanlpMapResult>> recommendTextListEntry,
|
||||||
Set<Long> possibleModels) {
|
Set<Long> possibleModels) {
|
||||||
List<MapResult> recommendValues = recommendTextListEntry.getValue();
|
List<HanlpMapResult> recommendValues = recommendTextListEntry.getValue();
|
||||||
return recommendValues.stream()
|
return recommendValues.stream()
|
||||||
.flatMap(entry -> entry.getNatures().stream()
|
.flatMap(entry -> entry.getNatures().stream()
|
||||||
.filter(nature -> {
|
.filter(nature -> {
|
||||||
@@ -288,26 +289,25 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
return possibleModels.contains(model);
|
return possibleModels.contains(model);
|
||||||
})
|
})
|
||||||
.map(nature -> {
|
.map(nature -> {
|
||||||
DictWord posDO = new DictWord();
|
DictWord posDO = new DictWord();
|
||||||
posDO.setWord(entry.getName());
|
posDO.setWord(entry.getName());
|
||||||
posDO.setNature(nature);
|
posDO.setNature(nature);
|
||||||
return posDO;
|
return posDO;
|
||||||
}
|
})).sorted(Comparator.comparingInt(a -> a.getWord().length()))
|
||||||
)).sorted(Comparator.comparingInt(a -> a.getWord().length()))
|
|
||||||
.collect(Collectors.toMap(DictWord::getNature, DictWord::getWord, (value1, value2) -> value1,
|
.collect(Collectors.toMap(DictWord::getNature, DictWord::getWord, (value1, value2) -> value1,
|
||||||
LinkedHashMap::new));
|
LinkedHashMap::new));
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
|
private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
|
||||||
Map.Entry<MatchText, List<MapResult>> searchTextEntry, Set<SearchResult> searchResults) {
|
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry, Set<SearchResult> searchResults) {
|
||||||
boolean existMetric = false;
|
boolean existMetric = false;
|
||||||
log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
|
log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
|
||||||
MatchText matchText = searchTextEntry.getKey();
|
MatchText matchText = searchTextEntry.getKey();
|
||||||
List<MapResult> mapResults = searchTextEntry.getValue();
|
List<HanlpMapResult> hanlpMapResults = searchTextEntry.getValue();
|
||||||
|
|
||||||
for (MapResult mapResult : mapResults) {
|
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
|
||||||
|
|
||||||
List<ModelWithSemanticType> dimensionMetricClassIds = mapResult.getNatures().stream()
|
List<ModelWithSemanticType> dimensionMetricClassIds = hanlpMapResult.getNatures().stream()
|
||||||
.map(nature -> new ModelWithSemanticType(NatureHelper.getModelId(nature),
|
.map(nature -> new ModelWithSemanticType(NatureHelper.getModelId(nature),
|
||||||
NatureHelper.convertToElementType(nature)))
|
NatureHelper.convertToElementType(nature)))
|
||||||
.filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());
|
.filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());
|
||||||
@@ -322,8 +322,8 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
SearchResult searchResult = SearchResult.builder()
|
SearchResult searchResult = SearchResult.builder()
|
||||||
.modelId(modelId)
|
.modelId(modelId)
|
||||||
.modelName(modelToName.get(modelId))
|
.modelName(modelToName.get(modelId))
|
||||||
.recommend(matchText.getRegText() + mapResult.getName())
|
.recommend(matchText.getRegText() + hanlpMapResult.getName())
|
||||||
.subRecommend(mapResult.getName())
|
.subRecommend(hanlpMapResult.getName())
|
||||||
.schemaElementType(semanticType)
|
.schemaElementType(semanticType)
|
||||||
.build();
|
.build();
|
||||||
//visibility to filter metrics
|
//visibility to filter metrics
|
||||||
@@ -332,13 +332,13 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
visibility = configService.getVisibilityByModelId(modelId);
|
visibility = configService.getVisibilityByModelId(modelId);
|
||||||
caffeineCache.put(modelId, visibility);
|
caffeineCache.put(modelId, visibility);
|
||||||
}
|
}
|
||||||
if (!visibility.getBlackMetricNameList().contains(mapResult.getName())
|
if (!visibility.getBlackMetricNameList().contains(hanlpMapResult.getName())
|
||||||
&& !visibility.getBlackDimNameList().contains(mapResult.getName())) {
|
&& !visibility.getBlackDimNameList().contains(hanlpMapResult.getName())) {
|
||||||
searchResults.add(searchResult);
|
searchResults.add(searchResult);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", mapResult, dimensionMetricClassIds,
|
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", hanlpMapResult,
|
||||||
possibleModels);
|
dimensionMetricClassIds, possibleModels);
|
||||||
}
|
}
|
||||||
log.info("searchMetricAndDimension searchResults:{}", searchResults);
|
log.info("searchMetricAndDimension searchResults:{}", searchResults);
|
||||||
return existMetric;
|
return existMetric;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import org.junit.jupiter.api.Test;
|
|||||||
/**
|
/**
|
||||||
* MatchStrategyImplTest
|
* MatchStrategyImplTest
|
||||||
*/
|
*/
|
||||||
class QueryMatchStrategyTest extends ContextTest {
|
class HanlpMatchStrategyTest extends ContextTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void match() {
|
void match() {
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package com.tencent.supersonic.knowledge.dictionary;
|
||||||
|
|
||||||
|
import com.google.common.base.Objects;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@ToString
|
||||||
|
public class EmbeddingResult extends MapResult {
|
||||||
|
|
||||||
|
private String id;
|
||||||
|
|
||||||
|
private double distance;
|
||||||
|
|
||||||
|
private Map<String, String> metadata;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (o == null || getClass() != o.getClass()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
EmbeddingResult that = (EmbeddingResult) o;
|
||||||
|
return Objects.equal(id, that.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hashCode(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package com.tencent.supersonic.knowledge.dictionary;
|
||||||
|
|
||||||
|
import com.google.common.base.Objects;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@ToString
|
||||||
|
public class HanlpMapResult extends MapResult {
|
||||||
|
|
||||||
|
private List<String> natures;
|
||||||
|
private int offset = 0;
|
||||||
|
|
||||||
|
private double similarity;
|
||||||
|
|
||||||
|
public HanlpMapResult(String name, List<String> natures, String detectWord) {
|
||||||
|
this.name = name;
|
||||||
|
this.natures = natures;
|
||||||
|
this.detectWord = detectWord;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (o == null || getClass() != o.getClass()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
HanlpMapResult hanlpMapResult = (HanlpMapResult) o;
|
||||||
|
return Objects.equal(name, hanlpMapResult.name) && Objects.equal(natures, hanlpMapResult.natures);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hashCode(name, natures);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setOffset(int offset) {
|
||||||
|
this.offset = offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
package com.tencent.supersonic.knowledge.dictionary;
|
package com.tencent.supersonic.knowledge.dictionary;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
@@ -10,43 +8,6 @@ import lombok.ToString;
|
|||||||
@ToString
|
@ToString
|
||||||
public class MapResult implements Serializable {
|
public class MapResult implements Serializable {
|
||||||
|
|
||||||
private String name;
|
protected String name;
|
||||||
private List<String> natures;
|
protected String detectWord;
|
||||||
private int offset = 0;
|
|
||||||
|
|
||||||
private double similarity;
|
|
||||||
|
|
||||||
private String detectWord;
|
|
||||||
|
|
||||||
public MapResult() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public MapResult(String name, List<String> natures, String detectWord) {
|
|
||||||
this.name = name;
|
|
||||||
this.natures = natures;
|
|
||||||
this.detectWord = detectWord;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (o == null || getClass() != o.getClass()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MapResult mapResult = (MapResult) o;
|
|
||||||
return Objects.equal(name, mapResult.name) && Objects.equal(natures, mapResult.natures);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return Objects.hashCode(name, natures);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setOffset(int offset) {
|
|
||||||
this.offset = offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -7,7 +7,7 @@ import com.hankcs.hanlp.dictionary.CoreDictionary;
|
|||||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.DictionaryAttributeUtil;
|
import com.tencent.supersonic.knowledge.dictionary.DictionaryAttributeUtil;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -38,17 +38,17 @@ public class SearchService {
|
|||||||
* @param key
|
* @param key
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static List<MapResult> prefixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
|
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
|
||||||
return prefixSearch(key, limit, agentId, trie, detectModelIds);
|
return prefixSearch(key, limit, agentId, trie, detectModelIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<MapResult> prefixSearch(String key, int limit, Integer agentId, BinTrie<List<String>> binTrie,
|
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId,
|
||||||
Set<Long> detectModelIds) {
|
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
|
||||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
|
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
|
||||||
return result.stream().map(
|
return result.stream().map(
|
||||||
entry -> {
|
entry -> {
|
||||||
String name = entry.getKey().replace("#", " ");
|
String name = entry.getKey().replace("#", " ");
|
||||||
return new MapResult(name, entry.getValue(), key);
|
return new HanlpMapResult(name, entry.getValue(), key);
|
||||||
}
|
}
|
||||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||||
.limit(SEARCH_SIZE)
|
.limit(SEARCH_SIZE)
|
||||||
@@ -60,13 +60,13 @@ public class SearchService {
|
|||||||
* @param key
|
* @param key
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static List<MapResult> suffixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
|
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
|
||||||
String reverseDetectSegment = StringUtils.reverse(key);
|
String reverseDetectSegment = StringUtils.reverse(key);
|
||||||
return suffixSearch(reverseDetectSegment, limit, agentId, suffixTrie, detectModelIds);
|
return suffixSearch(reverseDetectSegment, limit, agentId, suffixTrie, detectModelIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<MapResult> suffixSearch(String key, int limit, Integer agentId, BinTrie<List<String>> binTrie,
|
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId,
|
||||||
Set<Long> detectModelIds) {
|
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
|
||||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
|
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
|
||||||
return result.stream().map(
|
return result.stream().map(
|
||||||
entry -> {
|
entry -> {
|
||||||
@@ -75,7 +75,7 @@ public class SearchService {
|
|||||||
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
|
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
name = StringUtils.reverse(name);
|
name = StringUtils.reverse(name);
|
||||||
return new MapResult(name, natures, key);
|
return new HanlpMapResult(name, natures, key);
|
||||||
}
|
}
|
||||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||||
.limit(SEARCH_SIZE)
|
.limit(SEARCH_SIZE)
|
||||||
|
|||||||
@@ -9,17 +9,16 @@ import com.hankcs.hanlp.seg.Segment;
|
|||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileNotFoundException;
|
import java.io.FileNotFoundException;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
|
||||||
import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -186,11 +185,11 @@ public class HanlpHelper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void transLetterOriginal(List<MapResult> mapResults) {
|
public static <T extends MapResult> void transLetterOriginal(List<T> mapResults) {
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
if (CollectionUtils.isEmpty(mapResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (MapResult mapResult : mapResults) {
|
for (T mapResult : mapResults) {
|
||||||
if (MultiCustomDictionary.isLowerLetter(mapResult.getName())) {
|
if (MultiCustomDictionary.isLowerLetter(mapResult.getName())) {
|
||||||
if (CustomDictionary.contains(mapResult.getName())) {
|
if (CustomDictionary.contains(mapResult.getName())) {
|
||||||
CoreDictionary.Attribute attribute = CustomDictionary.get(mapResult.getName());
|
CoreDictionary.Attribute attribute = CustomDictionary.get(mapResult.getName());
|
||||||
|
|||||||
@@ -1,18 +1,28 @@
|
|||||||
package com.tencent.supersonic.common.util.embedding;
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class Retrieval {
|
public class Retrieval {
|
||||||
|
|
||||||
private Long id;
|
protected String id;
|
||||||
|
|
||||||
private double distance;
|
protected double distance;
|
||||||
|
|
||||||
private String query;
|
protected String query;
|
||||||
|
|
||||||
private Map<String, String> metadata;
|
protected Map<String, String> metadata;
|
||||||
|
|
||||||
|
|
||||||
|
public static Long getLongId(String id) {
|
||||||
|
if (StringUtils.isBlank(id)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String[] split = id.split(DictWordType.NATURE_SPILT);
|
||||||
|
return Long.parseLong(split[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -264,7 +264,8 @@ class SqlParserAddHelperTest {
|
|||||||
+ "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
|
+ "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
|
||||||
replaceSql);
|
replaceSql);
|
||||||
|
|
||||||
sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11' and count(DISTINCT uv) >1 "
|
sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11'"
|
||||||
|
+ " and count(DISTINCT uv) >1 "
|
||||||
+ "GROUP BY department order by count(DISTINCT uv) desc limit 10";
|
+ "GROUP BY department order by count(DISTINCT uv) desc limit 10";
|
||||||
replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate);
|
replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate);
|
||||||
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
|
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
|
||||||
@@ -290,7 +291,8 @@ class SqlParserAddHelperTest {
|
|||||||
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
|
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND count(DISTINCT uv) > 1 "
|
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = "
|
||||||
|
+ "'2023-09-11' AND count(DISTINCT uv) > 1 "
|
||||||
+ "AND department = 'HR' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
|
+ "AND department = 'HR' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
|
||||||
replaceSql);
|
replaceSql);
|
||||||
|
|
||||||
@@ -300,8 +302,10 @@ class SqlParserAddHelperTest {
|
|||||||
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
|
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE (count(DISTINCT uv) > 1 AND department = 'HR') AND "
|
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE (count(DISTINCT uv) > "
|
||||||
+ "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
|
+ "1 AND department = 'HR') AND "
|
||||||
|
+ "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY "
|
||||||
|
+ "count(DISTINCT uv) DESC LIMIT 10",
|
||||||
replaceSql);
|
replaceSql);
|
||||||
|
|
||||||
sql = "select department, count(DISTINCT uv) as uv from t_1 where sys_imp_date = '2023-09-11' GROUP BY "
|
sql = "select department, count(DISTINCT uv) as uv from t_1 where sys_imp_date = '2023-09-11' GROUP BY "
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||||
|
com.tencent.supersonic.chat.mapper.EmbeddingMapper, \
|
||||||
com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
|
com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
|
||||||
com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
|
com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
|
||||||
com.tencent.supersonic.chat.mapper.QueryFilterMapper, \
|
com.tencent.supersonic.chat.mapper.QueryFilterMapper, \
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||||
|
com.tencent.supersonic.chat.mapper.EmbeddingMapper, \
|
||||||
com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
|
com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
|
||||||
com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
|
com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
|
||||||
com.tencent.supersonic.chat.mapper.QueryFilterMapper, \
|
com.tencent.supersonic.chat.mapper.QueryFilterMapper, \
|
||||||
|
|||||||
@@ -2,18 +2,18 @@ package com.tencent.supersonic.semantic.model.domain.listener;
|
|||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.context.ApplicationListener;
|
import org.springframework.context.ApplicationListener;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -30,15 +30,17 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<EmbeddingQuery> embeddingQueries = event.getDataItems()
|
List<EmbeddingQuery> embeddingQueries = event.getDataItems()
|
||||||
.stream().filter(dataItem -> dataItem.getType().equals(TypeEnums.METRIC)).map(dataItem -> {
|
.stream()
|
||||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
.map(dataItem -> {
|
||||||
embeddingQuery.setQueryId(dataItem.getId().toString());
|
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||||
embeddingQuery.setQuery(dataItem.getName());
|
embeddingQuery.setQueryId(
|
||||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
dataItem.getId().toString() + DictWordType.NATURE_SPILT + dataItem.getType().getName());
|
||||||
embeddingQuery.setMetadata(meta);
|
embeddingQuery.setQuery(dataItem.getName());
|
||||||
embeddingQuery.setQueryEmbedding(null);
|
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||||
return embeddingQuery;
|
embeddingQuery.setMetadata(meta);
|
||||||
}).collect(Collectors.toList());
|
embeddingQuery.setQueryEmbedding(null);
|
||||||
|
return embeddingQuery;
|
||||||
|
}).collect(Collectors.toList());
|
||||||
if (CollectionUtils.isEmpty(embeddingQueries)) {
|
if (CollectionUtils.isEmpty(embeddingQueries)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user