(improvement)(chat) Support semantic understanding and optimize the overall mapper code. (#321)

This commit is contained in:
lexluo09
2023-11-05 13:14:07 +08:00
committed by GitHub
parent 2fe56e7462
commit 910384d17f
25 changed files with 716 additions and 375 deletions

View File

@@ -42,4 +42,18 @@ public class OptimizationConfig {
@Value("${user.s2ql.switch:false}")
private boolean useS2qlSwitch;
@Value("${embedding.mapper.word.min:2}")
private int embeddingMapperWordMin;
@Value("${embedding.mapper.number:10}")
private int embeddingMapperNumber;
@Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.sum.number:10}")
private int embeddingMapperSumNumber;
@Value("${embedding.mapper.distance.threshold:0.3}")
private Double embeddingMapperDistanceThreshold;
}
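These mapper settings can be overridden through any Spring property source; a minimal sketch of a properties override (the keys come straight from the @Value annotations above, and the values shown simply restate the built-in defaults):

# hypothetical application.properties override for the new embedding mapper settings
embedding.mapper.word.min=2
embedding.mapper.number=10
embedding.mapper.round.number=10
embedding.mapper.sum.number=10
embedding.mapper.distance.threshold=0.3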

View File

@@ -1,13 +1,26 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
/**
* Base mapper
@@ -19,12 +32,17 @@ public abstract class BaseMapper implements SchemaMapper {
public void map(QueryContext queryContext) {
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo());
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo());
try {
work(queryContext);
} catch (Exception e) {
log.error("work error", e);
}
work(queryContext);
log.debug("after {},mapInfo:{}", simpleName, queryContext.getMapInfo());
long cost = System.currentTimeMillis() - startTime;
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo());
}
public abstract void work(QueryContext queryContext);
@@ -38,4 +56,63 @@ public abstract class BaseMapper implements SchemaMapper {
}
schemaElementMatches.add(schemaElementMatch);
}
public Set<Long> getModelIds(QueryContext queryContext) {
return ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest());
}
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
logTerms(terms);
if (CollectionUtils.isNotEmpty(detectModelIds)) {
terms = terms.stream().filter(term -> {
Long modelId = NatureHelper.getModelId(term.getNature().toString());
if (Objects.nonNull(modelId)) {
return detectModelIds.contains(modelId);
}
return false;
}).collect(Collectors.toList());
log.info("terms filter by modelIds:{}", detectModelIds);
logTerms(terms);
}
return terms;
}
public void logTerms(List<Term> terms) {
if (CollectionUtils.isEmpty(terms)) {
return;
}
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
}
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
SchemaElement element = new SchemaElement();
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
if (Objects.isNull(modelSchema)) {
return null;
}
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
if (Objects.isNull(elementDb)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
return null;
}
BeanUtils.copyProperties(elementDb, element);
element.setAlias(getAlias(elementDb));
return element;
}
public List<String> getAlias(SchemaElement element) {
if (!SchemaElementType.VALUE.equals(element.getType())) {
return element.getAlias();
}
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
element.getName())) {
return element.getAlias().stream()
.filter(aliasItem -> aliasItem.contains(element.getName()))
.collect(Collectors.toList());
}
return element.getAlias();
}
}
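Because map() now wraps work() with timing and exception isolation, a concrete mapper only has to implement work(). A minimal sketch of a subclass (the class is illustrative, not part of this commit):

@Slf4j
public class NoopMapper extends BaseMapper {
    @Override
    public void work(QueryContext queryContext) {
        // Any exception thrown here is caught and logged by BaseMapper.map(),
        // so one failing mapper no longer aborts the whole mapper chain.
        log.debug("models in scope:{}", getModelIds(queryContext));
    }
}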

View File

@@ -0,0 +1,121 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* match strategy implementation
*/
@Service
@Slf4j
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Autowired
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
List<T> detects = detect(queryReq, terms, detectModelIds);
Map<MatchText, List<T>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
public List<T> detect(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
String text = queryReq.getQueryText();
Set<T> results = new HashSet<>();
for (Integer index = 0; index <= text.length() - 1; ) {
for (Integer i = index; i <= text.length(); ) {
int offset = mapperHelper.getStepOffset(terms, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
detectByStep(queryReq, results, detectModelIds, index, i, offset);
}
}
index = mapperHelper.getStepIndex(regOffsetToLength, index);
}
return new ArrayList<>(results);
}
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
return terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
}
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
if (CollectionUtils.isEmpty(oneRoundResults)) {
return;
}
for (T oneRoundResult : oneRoundResults) {
if (existResults.contains(oneRoundResult)) {
boolean isDeleted = existResults.removeIf(
existResult -> {
boolean delete = needDelete(oneRoundResult, existResult);
if (delete) {
log.info("deleted existResult:{}", existResult);
}
return delete;
}
);
if (isDeleted) {
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
existResults.add(oneRoundResult);
}
} else {
existResults.add(oneRoundResult);
}
}
}
public List<T> getMatches(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
Map<MatchText, List<T>> matchResult = match(queryReq, terms, detectModelIds);
List<T> matches = new ArrayList<>();
if (Objects.isNull(matchResult)) {
return matches;
}
Optional<List<T>> first = matchResult.entrySet().stream()
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
.map(entry -> entry.getValue()).findFirst();
if (first.isPresent()) {
matches = first.get();
}
return matches;
}
public abstract boolean needDelete(T oneRoundResult, T existResult);
public abstract String getMapKey(T a);
public abstract void detectByStep(QueryReq queryReq, Set<T> results, Set<Long> detectModelIds, Integer startIndex,
Integer index, int offset);
}
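detect() enumerates candidate substrings of the query by stepping both boundaries along segmentation offsets instead of single characters, which caps how many segments each strategy has to score. A self-contained sketch of that walk; stepIndex here is an assumption about MapperHelper.getStepIndex, which this diff does not show:

import java.util.HashMap;
import java.util.Map;

public class DetectWalkSketch {
    // Assumed behavior: jump past a whole term when i sits on a term offset,
    // otherwise advance by one character.
    static int stepIndex(Map<Integer, Integer> regOffsetToLength, int i) {
        Integer len = regOffsetToLength.get(i);
        return len != null ? i + len : i + 1;
    }

    public static void main(String[] args) {
        String text = "净利润同比";
        Map<Integer, Integer> regOffsetToLength = new HashMap<>();
        regOffsetToLength.put(0, 3); // pretend "净利润" was segmented at offset 0
        for (int index = 0; index <= text.length() - 1; ) {
            for (int i = index; i <= text.length(); ) {
                i = stepIndex(regOffsetToLength, i);
                if (i <= text.length()) {
                    System.out.println(text.substring(index, i)); // candidate detect segment
                }
            }
            index = stepIndex(regOffsetToLength, index);
        }
    }
}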

View File

@@ -1,7 +1,19 @@
package com.tencent.supersonic.chat.mapper;
import com.alibaba.fastjson.JSONObject;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.List;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* A mapper that is capable of semantic understanding of text.
@@ -13,10 +25,44 @@ public class EmbeddingMapper extends BaseMapper {
public void work(QueryContext queryContext) {
//1. query from embedding by queryText
String queryText = queryContext.getRequest().getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText);
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
Set<Long> detectModelIds = getModelIds(queryContext);
terms = filterByModelIds(terms, detectModelIds);
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
HanlpHelper.transLetterOriginal(matchResults);
//2. build SchemaElementMatch by info
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
SchemaElement.class);
String modelIdStr = matchResult.getMetadata().get("modelId");
if (StringUtils.isBlank(modelIdStr)) {
continue;
}
long modelId = Long.parseLong(modelIdStr);
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(matchResult.getName())
.similarity(mapperHelper.getSimilarity(matchResult.getName(), matchResult.getDetectWord()))
.detectWord(matchResult.getDetectWord())
.build();
//3. add to mapInfo
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
}
}
}

View File

@@ -0,0 +1,116 @@
package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* match strategy implementation
*/
@Service
@Slf4j
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private EmbeddingUtils embeddingUtils;
@Override
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
&& existResult.getDistance() > oneRoundResult.getDistance();
}
@Override
public String getMapKey(EmbeddingResult a) {
return a.getName() + Constants.UNDERLINE + a.getId();
}
public void detectByStep(QueryReq queryReq, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) {
String detectSegment = queryReq.getQueryText().substring(startIndex, index);
// step1. build query params
if (StringUtils.isBlank(detectSegment)
|| detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMin()) {
return;
}
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
Map<String, String> filterCondition = null;
// if only one modelId, add to filterCondition
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
filterCondition = new HashMap<>();
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
}
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
.queryTextsList(Collections.singletonList(detectSegment))
.filterCondition(filterCondition)
.queryEmbeddings(null)
.build();
// step2. retrieveQuery by detectSegment
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
}
// step3. build EmbeddingResults. filter by modelId
List<EmbeddingResult> collect = retrieveQueryResults.stream()
.map(retrieveQueryResult -> {
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
if (CollectionUtils.isNotEmpty(retrievals)) {
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
if (CollectionUtils.isNotEmpty(detectModelIds)) {
retrievals.removeIf(retrieval -> {
String modelIdStr = retrieval.getMetadata().get("modelId");
if (StringUtils.isBlank(modelIdStr)) {
return true;
}
return !detectModelIds.contains(Long.parseLong(modelIdStr));
});
}
}
return retrieveQueryResult;
})
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
.map(retrieval -> {
EmbeddingResult embeddingResult = new EmbeddingResult();
BeanUtils.copyProperties(retrieval, embeddingResult);
embeddingResult.setDetectWord(detectSegment);
embeddingResult.setName(retrieval.getQuery());
return embeddingResult;
}))
.collect(Collectors.toList());
// step4. select mapResult in one round
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber();
List<EmbeddingResult> oneRoundResults = collect.stream()
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
.limit(roundNumber)
.collect(Collectors.toList());
selectResultInOneRound(existResults, oneRoundResults);
}
}
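For reference, the retrieval above boils down to one RetrieveQuery per candidate segment, with results later pruned by the distance threshold. A hedged usage sketch (segment text and model id are made-up values; the builder fields and retrieveQuery signature mirror the code above):

RetrieveQuery retrieveQuery = RetrieveQuery.builder()
        .queryTextsList(Collections.singletonList("净利润"))        // one detect segment
        .filterCondition(Collections.singletonMap("modelId", "1")) // only when exactly one model is in scope
        .queryEmbeddings(null)                                     // let the store embed the text itself
        .build();
List<RetrieveQueryResult> results = embeddingUtils.retrieveQuery(
        MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, 10);
// retrievals with distance > embedding.mapper.distance.threshold (0.3 by default) are dropped afterwards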

View File

@@ -1,28 +1,22 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
/**
* A mapper capable of prefix and suffix similarity parsing for
@@ -37,45 +31,23 @@ public class HanlpDictMapper extends BaseMapper {
String queryText = queryContext.getRequest().getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText);
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
HanlpMatchStrategy matchStrategy = ContextUtils.getBean(HanlpMatchStrategy.class);
Set<Long> detectModelIds = ContextUtils.getBean(MapperHelper.class).getModelIds(queryContext.getRequest());
Set<Long> detectModelIds = getModelIds(queryContext);
terms = filterByModelIds(terms, detectModelIds);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
detectModelIds);
List<MapResult> matches = getMatches(matchResult);
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext.getRequest(), terms, detectModelIds);
HanlpHelper.transLetterOriginal(matches);
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
}
private List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
for (Term term : terms) {
log.info("before word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
if (CollectionUtils.isNotEmpty(detectModelIds)) {
terms = terms.stream().filter(term -> {
Long modelId = NatureHelper.getModelId(term.getNature().toString());
if (Objects.nonNull(modelId)) {
return detectModelIds.contains(modelId);
}
return false;
}).collect(Collectors.toList());
}
for (Term term : terms) {
log.info("after filter word:{},nature:{},frequency:{}", term.word, term.nature.toString(),
term.getFrequency());
}
return terms;
}
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
private void convertTermsToSchemaMapInfo(List<HanlpMapResult> hanlpMapResults, SchemaMapInfo schemaMap,
List<Term> terms) {
if (CollectionUtils.isEmpty(hanlpMapResults)) {
return;
}
@@ -83,8 +55,8 @@ public class HanlpDictMapper extends BaseMapper {
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
for (MapResult mapResult : mapResults) {
for (String nature : mapResult.getNatures()) {
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
for (String nature : hanlpMapResult.getNatures()) {
Long modelId = NatureHelper.getModelId(nature);
if (Objects.isNull(modelId)) {
continue;
@@ -93,33 +65,21 @@ public class HanlpDictMapper extends BaseMapper {
if (Objects.isNull(elementType)) {
continue;
}
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
if (Objects.isNull(modelSchema)) {
return;
}
Long elementID = NatureHelper.getElementID(nature);
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
if (Objects.isNull(elementDb)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
if (element == null) {
continue;
}
SchemaElement element = new SchemaElement();
BeanUtils.copyProperties(elementDb, element);
element.setAlias(getAlias(elementDb));
if (element.getType().equals(SchemaElementType.VALUE)) {
element.setName(mapResult.getName());
element.setName(hanlpMapResult.getName());
}
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
.frequency(frequency)
.word(mapResult.getName())
.similarity(mapResult.getSimilarity())
.detectWord(mapResult.getDetectWord())
.word(hanlpMapResult.getName())
.similarity(hanlpMapResult.getSimilarity())
.detectWord(hanlpMapResult.getDetectWord())
.build();
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
@@ -127,30 +87,5 @@ public class HanlpDictMapper extends BaseMapper {
}
}
private List<MapResult> getMatches(Map<MatchText, List<MapResult>> matchResult) {
List<MapResult> matches = new ArrayList<>();
if (Objects.isNull(matchResult)) {
return matches;
}
Optional<List<MapResult>> first = matchResult.entrySet().stream()
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
.map(entry -> entry.getValue()).findFirst();
if (first.isPresent()) {
matches = first.get();
}
return matches;
}
public List<String> getAlias(SchemaElement element) {
if (!SchemaElementType.VALUE.equals(element.getType())) {
return element.getAlias();
}
if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) {
return element.getAlias().stream()
.filter(aliasItem -> aliasItem.contains(element.getName()))
.collect(Collectors.toList());
}
return element.getAlias();
}
}

View File

@@ -0,0 +1,119 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* match strategy implementation
*/
@Service
@Slf4j
public class HanlpMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Autowired
private MapperHelper mapperHelper;
@Autowired
private OptimizationConfig optimizationConfig;
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
List<HanlpMapResult> detects = detect(queryReq, terms, detectModelIds);
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
@Override
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
}
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) {
String text = queryReq.getQueryText();
Integer agentId = queryReq.getAgentId();
String detectSegment = text.substring(startIndex, index);
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
agentId,
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
oneDetectionMaxSize,
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
hanlpMapResults.addAll(suffixHanlpMapResults);
if (CollectionUtils.isEmpty(hanlpMapResults)) {
return;
}
// step3. merge pre/suffix result
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by similarity
hanlpMapResults = hanlpMapResults.stream()
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
>= mapperHelper.getThresholdMatch(term.getNatures()))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after isSimilarity parseResults:{}", hanlpMapResults);
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
parseResult.setOffset(offset);
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only one dimension value, or up to 10 metric/dimension results, per round.
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.collect(Collectors.toList())
.stream()
.limit(1)
.collect(Collectors.toList());
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
oneRoundResults = dimensionMetrics;
}
// step6. select mapResult in one round
selectResultInOneRound(existResults, oneRoundResults);
}
public String getMapKey(HanlpMapResult a) {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
}
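The needDelete rule above prefers the candidate matched from the longer detect segment: when two rounds yield the same name-and-natures key, the entry with the shorter detectWord is evicted. A small illustration (all values hypothetical):

List<String> natures = Collections.singletonList("_1_3"); // placeholder nature string
HanlpMapResult shorter = new HanlpMapResult("net profit", natures, "net");
HanlpMapResult longer = new HanlpMapResult("net profit", natures, "net prof");
// getMapKey(shorter).equals(getMapKey(longer)) is true, and
// needDelete(longer, shorter) returns true, so selectResultInOneRound
// replaces the existing "shorter" entry with "longer".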

View File

@@ -1,11 +1,13 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -39,10 +41,14 @@ public class MapperHelper {
return index;
}
public Integer getStepOffset(List<Integer> termList, Integer index) {
public Integer getStepOffset(List<Term> termList, Integer index) {
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
for (int j = 0; j < termList.size() - 1; j++) {
if (termList.get(j) <= index && termList.get(j + 1) > index) {
return termList.get(j);
if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
return offsetList.get(j);
}
}
return index;
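The rewritten getStepOffset sorts the terms itself and still returns the largest term offset that does not exceed index. A worked example over plain offsets, equivalent to the loop above:

// offsets 0, 3, 7; index 5 falls between 3 and 7, so the step offset is 3
List<Integer> offsetList = Arrays.asList(0, 3, 7);
int index = 5;
int stepOffset = index; // fallback mirrors the final `return index`
for (int j = 0; j < offsetList.size() - 1; j++) {
    if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
        stepOffset = offsetList.get(j);
        break;
    }
}
// stepOffset == 3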

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -10,8 +9,8 @@ import java.util.Set;
/**
* match strategy
*/
public interface MatchStrategy {
public interface MatchStrategy<T> {
Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
Map<MatchText, List<T>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
}

View File

@@ -1,163 +0,0 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.compress.utils.Lists;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* match strategy implementation
*/
@Service
@Slf4j
public class QueryMatchStrategy implements MatchStrategy {
@Autowired
private MapperHelper mapperHelper;
@Autowired
private OptimizationConfig optimizationConfig;
@Override
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms,
regOffsetToLength, offsetList, detectModelIds);
List<MapResult> detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds);
Map<MatchText, List<MapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
private List<MapResult> detect(QueryReq queryReq, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
List<MapResult> results = Lists.newArrayList();
for (Integer index = 0; index <= text.length() - 1; ) {
Set<MapResult> mapResultRowSet = new LinkedHashSet();
for (Integer i = index; i <= text.length(); ) {
int offset = mapperHelper.getStepOffset(offsetList, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
List<MapResult> mapResults = detectByStep(queryReq, detectModelIds, index, i, offset);
selectMapResultInOneRound(mapResultRowSet, mapResults);
}
}
index = mapperHelper.getStepIndex(regOffsetToLength, index);
results.addAll(mapResultRowSet);
}
return results;
}
private void selectMapResultInOneRound(Set<MapResult> mapResultRowSet, List<MapResult> mapResults) {
for (MapResult mapResult : mapResults) {
if (mapResultRowSet.contains(mapResult)) {
boolean isDeleted = mapResultRowSet.removeIf(
entry -> {
boolean deleted = getMapKey(mapResult).equals(getMapKey(entry))
&& entry.getDetectWord().length() < mapResult.getDetectWord().length();
if (deleted) {
log.info("deleted entry:{}", entry);
}
return deleted;
}
);
if (isDeleted) {
log.info("deleted, add mapResult:{}", mapResult);
mapResultRowSet.add(mapResult);
}
} else {
mapResultRowSet.add(mapResult);
}
}
}
private String getMapKey(MapResult a) {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
private List<MapResult> detectByStep(QueryReq queryReq, Set<Long> detectModelIds, Integer index, Integer i,
int offset) {
String text = queryReq.getQueryText();
Integer agentId = queryReq.getAgentId();
String detectSegment = text.substring(index, i);
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize,
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
mapResults.addAll(suffixMapResults);
if (CollectionUtils.isEmpty(mapResults)) {
return new ArrayList<>();
}
// step3. merge pre/suffix result
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by similarity
mapResults = mapResults.stream()
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
>= mapperHelper.getThresholdMatch(term.getNatures()))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after isSimilarity parseResults:{}", mapResults);
mapResults = mapResults.stream().map(parseResult -> {
parseResult.setOffset(offset);
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only one dimension value, or up to 10 metric/dimension results, per round.
List<MapResult> dimensionMetrics = mapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.collect(Collectors.toList())
.stream()
.limit(1)
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
return dimensionMetrics;
} else {
return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList());
}
}
}

View File

@@ -4,7 +4,7 @@ import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List;
import java.util.Map;
@@ -20,17 +20,15 @@ import org.springframework.stereotype.Service;
* match strategy implementation
*/
@Service
public class SearchMatchStrategy implements MatchStrategy {
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
private static final int SEARCH_SIZE = 3;
@Override
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> originals, Set<Long> detectModelIds) {
public Map<MatchText, List<HanlpMapResult>> match(QueryReq queryReq, List<Term> originals,
Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
Map<Integer, Integer> regOffsetToLength = originals.stream()
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
.collect(Collectors.toMap(Term::getOffset, value -> value.word.length(),
(value1, value2) -> value2));
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
List<Integer> detectIndexList = Lists.newArrayList();
@@ -46,19 +44,19 @@ public class SearchMatchStrategy implements MatchStrategy {
index++;
}
}
Map<MatchText, List<MapResult>> regTextMap = new ConcurrentHashMap<>();
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
detectIndexList.stream().parallel().forEach(detectIndex -> {
String regText = text.substring(0, detectIndex);
String detectSegment = text.substring(detectIndex);
if (StringUtils.isNotEmpty(detectSegment)) {
List<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE,
queryReq.getAgentId(), detectModelIds);
mapResults.addAll(suffixMapResults);
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity-name natures from the search results
mapResults = mapResults.stream().filter(entry -> {
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
.collect(Collectors.toList());
@@ -71,10 +69,27 @@ public class SearchMatchStrategy implements MatchStrategy {
.regText(regText)
.detectSegment(detectSegment)
.build();
regTextMap.put(matchText, mapResults);
regTextMap.put(matchText, hanlpMapResults);
}
}
);
return regTextMap;
}
@Override
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
return false;
}
@Override
public String getMapKey(HanlpMapResult a) {
return null;
}
@Override
public void detectByStep(QueryReq queryReq, Set<HanlpMapResult> results, Set<Long> detectModelIds,
Integer startIndex,
Integer i, int offset) {
}
}

View File

@@ -57,7 +57,7 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder {
metric.setOrder(metricOrder++);
}
for (Retrieval retrieval : retrievals) {
if (!metricIds.contains(retrieval.getId())) {
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
SchemaElement.class);
if (retrieval.getMetadata().containsKey("modelId")) {

View File

@@ -47,7 +47,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
@@ -628,10 +628,10 @@ public class QueryServiceImpl implements QueryService {
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
}
// search via prefixSearch
List<MapResult> mapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(),
List<HanlpMapResult> hanlpMapResultList = SearchService.prefixSearch(dimensionValueReq.getValue(),
2000, dimensionValueReq.getAgentId(), detectModelIds);
HanlpHelper.transLetterOriginal(mapResultList);
return mapResultList.stream()
HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream()
.filter(o -> {
for (String nature : o.getNatures()) {
Long elementID = NatureHelper.getElementID(nature);

View File

@@ -24,7 +24,7 @@ import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.SearchService;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
@@ -94,11 +94,12 @@ public class SearchServiceImpl implements SearchService {
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq);
Map<MatchText, List<MapResult>> regTextMap = searchMatchStrategy.match(queryReq, originals, detectModelIds);
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryReq, originals, detectModelIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
// 4.get the most matching data
Optional<Entry<MatchText, List<MapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
Optional<Entry<MatchText, List<HanlpMapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
.stream()
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
.reduce((entry1, entry2) ->
@@ -109,7 +110,7 @@ public class SearchServiceImpl implements SearchService {
if (!mostSimilarSearchResult.isPresent()) {
return Lists.newArrayList();
}
Map.Entry<MatchText, List<MapResult>> searchTextEntry = mostSimilarSearchResult.get();
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get();
log.info("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq);
Set<SearchResult> searchResults = new LinkedHashSet();
@@ -275,9 +276,9 @@ public class SearchServiceImpl implements SearchService {
* @param recommendTextListEntry
* @return
*/
private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<MapResult>> recommendTextListEntry,
private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<HanlpMapResult>> recommendTextListEntry,
Set<Long> possibleModels) {
List<MapResult> recommendValues = recommendTextListEntry.getValue();
List<HanlpMapResult> recommendValues = recommendTextListEntry.getValue();
return recommendValues.stream()
.flatMap(entry -> entry.getNatures().stream()
.filter(nature -> {
@@ -288,26 +289,25 @@ public class SearchServiceImpl implements SearchService {
return possibleModels.contains(model);
})
.map(nature -> {
DictWord posDO = new DictWord();
posDO.setWord(entry.getName());
posDO.setNature(nature);
return posDO;
}
)).sorted(Comparator.comparingInt(a -> a.getWord().length()))
DictWord posDO = new DictWord();
posDO.setWord(entry.getName());
posDO.setNature(nature);
return posDO;
})).sorted(Comparator.comparingInt(a -> a.getWord().length()))
.collect(Collectors.toMap(DictWord::getNature, DictWord::getWord, (value1, value2) -> value1,
LinkedHashMap::new));
}
private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
Map.Entry<MatchText, List<MapResult>> searchTextEntry, Set<SearchResult> searchResults) {
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry, Set<SearchResult> searchResults) {
boolean existMetric = false;
log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
MatchText matchText = searchTextEntry.getKey();
List<MapResult> mapResults = searchTextEntry.getValue();
List<HanlpMapResult> hanlpMapResults = searchTextEntry.getValue();
for (MapResult mapResult : mapResults) {
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
List<ModelWithSemanticType> dimensionMetricClassIds = mapResult.getNatures().stream()
List<ModelWithSemanticType> dimensionMetricClassIds = hanlpMapResult.getNatures().stream()
.map(nature -> new ModelWithSemanticType(NatureHelper.getModelId(nature),
NatureHelper.convertToElementType(nature)))
.filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());
@@ -322,8 +322,8 @@ public class SearchServiceImpl implements SearchService {
SearchResult searchResult = SearchResult.builder()
.modelId(modelId)
.modelName(modelToName.get(modelId))
.recommend(matchText.getRegText() + mapResult.getName())
.subRecommend(mapResult.getName())
.recommend(matchText.getRegText() + hanlpMapResult.getName())
.subRecommend(hanlpMapResult.getName())
.schemaElementType(semanticType)
.build();
//visibility to filter metrics
@@ -332,13 +332,13 @@ public class SearchServiceImpl implements SearchService {
visibility = configService.getVisibilityByModelId(modelId);
caffeineCache.put(modelId, visibility);
}
if (!visibility.getBlackMetricNameList().contains(mapResult.getName())
&& !visibility.getBlackDimNameList().contains(mapResult.getName())) {
if (!visibility.getBlackMetricNameList().contains(hanlpMapResult.getName())
&& !visibility.getBlackDimNameList().contains(hanlpMapResult.getName())) {
searchResults.add(searchResult);
}
}
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", mapResult, dimensionMetricClassIds,
possibleModels);
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", hanlpMapResult,
dimensionMetricClassIds, possibleModels);
}
log.info("searchMetricAndDimension searchResults:{}", searchResults);
return existMetric;

View File

@@ -6,7 +6,7 @@ import org.junit.jupiter.api.Test;
/**
* HanlpMatchStrategyTest
*/
class QueryMatchStrategyTest extends ContextTest {
class HanlpMatchStrategyTest extends ContextTest {
@Test
void match() {

View File

@@ -0,0 +1,34 @@
package com.tencent.supersonic.knowledge.dictionary;
import com.google.common.base.Objects;
import java.util.Map;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class EmbeddingResult extends MapResult {
private String id;
private double distance;
private Map<String, String> metadata;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
EmbeddingResult that = (EmbeddingResult) o;
return Objects.equal(id, that.id);
}
@Override
public int hashCode() {
return Objects.hashCode(id);
}
}
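Equality here is by id alone, so the HashSet used for existResults in BaseMatchStrategy deduplicates retrievals of the same element matched from different segments. A quick illustration (the id format is an assumption, consistent with Retrieval.getLongId below):

EmbeddingResult a = new EmbeddingResult();
a.setId("12_metric");
a.setDistance(0.10);
EmbeddingResult b = new EmbeddingResult();
b.setId("12_metric");
b.setDistance(0.25);
Set<EmbeddingResult> results = new HashSet<>(Arrays.asList(a, b));
// results.size() == 1: equals()/hashCode() consider only id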

View File

@@ -0,0 +1,44 @@
package com.tencent.supersonic.knowledge.dictionary;
import com.google.common.base.Objects;
import java.util.List;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class HanlpMapResult extends MapResult {
private List<String> natures;
private int offset = 0;
private double similarity;
public HanlpMapResult(String name, List<String> natures, String detectWord) {
this.name = name;
this.natures = natures;
this.detectWord = detectWord;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
HanlpMapResult hanlpMapResult = (HanlpMapResult) o;
return Objects.equal(name, hanlpMapResult.name) && Objects.equal(natures, hanlpMapResult.natures);
}
@Override
public int hashCode() {
return Objects.hashCode(name, natures);
}
public void setOffset(int offset) {
this.offset = offset;
}
}

View File

@@ -1,8 +1,6 @@
package com.tencent.supersonic.knowledge.dictionary;
import com.google.common.base.Objects;
import java.io.Serializable;
import java.util.List;
import lombok.Data;
import lombok.ToString;
@@ -10,43 +8,6 @@ import lombok.ToString;
@ToString
public class MapResult implements Serializable {
private String name;
private List<String> natures;
private int offset = 0;
private double similarity;
private String detectWord;
public MapResult() {
}
public MapResult(String name, List<String> natures, String detectWord) {
this.name = name;
this.natures = natures;
this.detectWord = detectWord;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
MapResult mapResult = (MapResult) o;
return Objects.equal(name, mapResult.name) && Objects.equal(natures, mapResult.natures);
}
@Override
public int hashCode() {
return Objects.hashCode(name, natures);
}
public void setOffset(int offset) {
this.offset = offset;
}
protected String name;
protected String detectWord;
}

View File

@@ -7,7 +7,7 @@ import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.DictionaryAttributeUtil;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -38,17 +38,17 @@ public class SearchService {
* @param key
* @return
*/
public static List<MapResult> prefixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
return prefixSearch(key, limit, agentId, trie, detectModelIds);
}
public static List<MapResult> prefixSearch(String key, int limit, Integer agentId, BinTrie<List<String>> binTrie,
Set<Long> detectModelIds) {
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId,
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
return result.stream().map(
entry -> {
String name = entry.getKey().replace("#", " ");
return new MapResult(name, entry.getValue(), key);
return new HanlpMapResult(name, entry.getValue(), key);
}
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.limit(SEARCH_SIZE)
@@ -60,13 +60,13 @@ public class SearchService {
* @param key
* @return
*/
public static List<MapResult> suffixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
String reverseDetectSegment = StringUtils.reverse(key);
return suffixSearch(reverseDetectSegment, limit, agentId, suffixTrie, detectModelIds);
}
public static List<MapResult> suffixSearch(String key, int limit, Integer agentId, BinTrie<List<String>> binTrie,
Set<Long> detectModelIds) {
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId,
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
return result.stream().map(
entry -> {
@@ -75,7 +75,7 @@ public class SearchService {
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
.collect(Collectors.toList());
name = StringUtils.reverse(name);
return new MapResult(name, natures, key);
return new HanlpMapResult(name, natures, key);
}
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.limit(SEARCH_SIZE)

View File

@@ -9,17 +9,16 @@ import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
@@ -186,11 +185,11 @@ public class HanlpHelper {
}
}
public static void transLetterOriginal(List<MapResult> mapResults) {
public static <T extends MapResult> void transLetterOriginal(List<T> mapResults) {
if (CollectionUtils.isEmpty(mapResults)) {
return;
}
for (MapResult mapResult : mapResults) {
for (T mapResult : mapResults) {
if (MultiCustomDictionary.isLowerLetter(mapResult.getName())) {
if (CustomDictionary.contains(mapResult.getName())) {
CoreDictionary.Attribute attribute = CustomDictionary.get(mapResult.getName());

View File

@@ -1,18 +1,28 @@
package com.tencent.supersonic.common.util.embedding;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import lombok.Data;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
@Data
public class Retrieval {
private Long id;
protected String id;
private double distance;
protected double distance;
private String query;
protected String query;
private Map<String, String> metadata;
protected Map<String, String> metadata;
public static Long getLongId(String id) {
if (StringUtils.isBlank(id)) {
return null;
}
String[] split = id.split(DictWordType.NATURE_SPILT);
return Long.parseLong(split[0]);
}
}
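getLongId undoes the composite queryId that MetaEmbeddingListener now writes (numeric id, separator, type name). Assuming DictWordType.NATURE_SPILT is the underscore used for natures elsewhere in the codebase:

Long id = Retrieval.getLongId("42_metric"); // -> 42L
Long none = Retrieval.getLongId("");        // -> null, blank ids are rejected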

View File

@@ -264,7 +264,8 @@ class SqlParserAddHelperTest {
+ "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
replaceSql);
sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11' and count(DISTINCT uv) >1 "
sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11'"
+ " and count(DISTINCT uv) >1 "
+ "GROUP BY department order by count(DISTINCT uv) desc limit 10";
replaceSql = SqlParserAddHelper.addAggregateToField(sql, filedNameToAggregate);
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
@@ -290,7 +291,8 @@ class SqlParserAddHelperTest {
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = '2023-09-11' AND count(DISTINCT uv) > 1 "
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = "
+ "'2023-09-11' AND count(DISTINCT uv) > 1 "
+ "AND department = 'HR' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
replaceSql);
@@ -300,8 +302,10 @@ class SqlParserAddHelperTest {
replaceSql = SqlParserAddHelper.addGroupBy(replaceSql, groupByFields);
Assert.assertEquals(
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE (count(DISTINCT uv) > 1 AND department = 'HR') AND "
+ "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10",
"SELECT department, count(DISTINCT uv) FROM t_1 WHERE (count(DISTINCT uv) > "
+ "1 AND department = 'HR') AND "
+ "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY "
+ "count(DISTINCT uv) DESC LIMIT 10",
replaceSql);
sql = "select department, count(DISTINCT uv) as uv from t_1 where sys_imp_date = '2023-09-11' GROUP BY "

View File

@@ -1,4 +1,5 @@
com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.mapper.EmbeddingMapper, \
com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
com.tencent.supersonic.chat.mapper.QueryFilterMapper, \

View File

@@ -1,4 +1,5 @@
com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.mapper.EmbeddingMapper, \
com.tencent.supersonic.chat.mapper.HanlpDictMapper, \
com.tencent.supersonic.chat.mapper.FuzzyNameMapper, \
com.tencent.supersonic.chat.mapper.QueryFilterMapper, \

View File

@@ -2,18 +2,18 @@ package com.tencent.supersonic.semantic.model.domain.listener;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Component
@Slf4j
@@ -30,15 +30,17 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
return;
}
List<EmbeddingQuery> embeddingQueries = event.getDataItems()
.stream().filter(dataItem -> dataItem.getType().equals(TypeEnums.METRIC)).map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(dataItem.getId().toString());
embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta);
embeddingQuery.setQueryEmbedding(null);
return embeddingQuery;
}).collect(Collectors.toList());
.stream()
.map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(
dataItem.getId().toString() + DictWordType.NATURE_SPILT + dataItem.getType().getName());
embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta);
embeddingQuery.setQueryEmbedding(null);
return embeddingQuery;
}).collect(Collectors.toList());
if (CollectionUtils.isEmpty(embeddingQueries)) {
return;
}