mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
(improvement)(chat) Performance optimization for Embedding Mapper, adding maximum and minimum text detection configurations (#335)
This commit is contained in:
@@ -42,18 +42,21 @@ public class OptimizationConfig {
|
|||||||
@Value("${user.s2ql.switch:false}")
|
@Value("${user.s2ql.switch:false}")
|
||||||
private boolean useS2qlSwitch;
|
private boolean useS2qlSwitch;
|
||||||
|
|
||||||
@Value("${embedding.mapper.word.min:2}")
|
@Value("${embedding.mapper.word.min:3}")
|
||||||
private int embeddingMapperWordMin;
|
private int embeddingMapperWordMin;
|
||||||
|
|
||||||
@Value("${embedding.mapper.number:10}")
|
@Value("${embedding.mapper.word.max:5}")
|
||||||
|
private int embeddingMapperWordMax;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.batch:50}")
|
||||||
|
private int embeddingMapperBatch;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.number:5}")
|
||||||
private int embeddingMapperNumber;
|
private int embeddingMapperNumber;
|
||||||
|
|
||||||
@Value("${embedding.mapper.round.number:10}")
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
private int embeddingMapperRoundNumber;
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
@Value("${embedding.mapper.sum.number:10}")
|
@Value("${embedding.mapper.distance.threshold:0.52}")
|
||||||
private int embeddingMapperSumNumber;
|
|
||||||
|
|
||||||
@Value("${embedding.mapper.distance.threshold:0.3}")
|
|
||||||
private Double embeddingMapperDistanceThreshold;
|
private Double embeddingMapperDistanceThreshold;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,20 +50,30 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
String text = queryContext.getRequest().getQueryText();
|
String text = queryContext.getRequest().getQueryText();
|
||||||
Set<T> results = new HashSet<>();
|
Set<T> results = new HashSet<>();
|
||||||
|
|
||||||
for (Integer index = 0; index <= text.length() - 1; ) {
|
Set<String> detectSegments = new HashSet<>();
|
||||||
|
|
||||||
for (Integer i = index; i <= text.length(); ) {
|
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
|
||||||
int offset = mapperHelper.getStepOffset(terms, index);
|
|
||||||
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
for (Integer index = startIndex; index <= text.length(); ) {
|
||||||
if (i <= text.length()) {
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
detectByStep(queryContext, results, detectModelIds, index, i, offset);
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
|
if (index <= text.length()) {
|
||||||
|
String detectSegment = text.substring(startIndex, index);
|
||||||
|
detectSegments.add(detectSegment);
|
||||||
|
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
}
|
}
|
||||||
|
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
||||||
|
Set<String> detectSegments) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
||||||
return terms.stream().sorted(Comparator.comparing(Term::length))
|
return terms.stream().sorted(Comparator.comparing(Term::length))
|
||||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
|||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -47,19 +46,32 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
|
||||||
Integer startIndex, Integer index, int offset) {
|
@Override
|
||||||
QueryReq queryReq = queryContext.getRequest();
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||||
String detectSegment = queryReq.getQueryText().substring(startIndex, index);
|
Set<String> detectSegments) {
|
||||||
// step1. build query params
|
|
||||||
if (StringUtils.isBlank(detectSegment)
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
|| detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMin()) {
|
.map(detectSegment -> detectSegment.trim())
|
||||||
return;
|
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||||
|
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
|
||||||
|
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||||
|
optimizationConfig.getEmbeddingMapperBatch());
|
||||||
|
|
||||||
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
|
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||||
|
List<String> queryTextsSub) {
|
||||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||||
Map<String, String> filterCondition = null;
|
Map<String, String> filterCondition = null;
|
||||||
|
// step1. build query params
|
||||||
// if only one modelId, add to filterCondition
|
// if only one modelId, add to filterCondition
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
||||||
filterCondition = new HashMap<>();
|
filterCondition = new HashMap<>();
|
||||||
@@ -67,7 +79,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
.queryTextsList(Collections.singletonList(detectSegment))
|
.queryTextsList(queryTextsSub)
|
||||||
.filterCondition(filterCondition)
|
.filterCondition(filterCondition)
|
||||||
.queryEmbeddings(null)
|
.queryEmbeddings(null)
|
||||||
.build();
|
.build();
|
||||||
@@ -101,18 +113,24 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
.map(retrieval -> {
|
.map(retrieval -> {
|
||||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||||
embeddingResult.setDetectWord(detectSegment);
|
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||||
embeddingResult.setName(retrieval.getQuery());
|
embeddingResult.setName(retrieval.getQuery());
|
||||||
return embeddingResult;
|
return embeddingResult;
|
||||||
}))
|
}))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
// step4. select mapResul in one round
|
// step4. select mapResul in one round
|
||||||
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber();
|
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
|
||||||
List<EmbeddingResult> oneRoundResults = collect.stream()
|
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||||
.limit(roundNumber)
|
.limit(roundNumber)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
selectResultInOneRound(existResults, oneRoundResults);
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user