mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(chat) Improve vector recall performance. (#1458)
This commit is contained in:
@@ -69,14 +69,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
detectByBatch(chatQueryContext, results, detectDataSetIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<T> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||
return terms.stream().sorted(Comparator.comparing(S2Term::length))
|
||||
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
@@ -15,7 +16,9 @@ import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@@ -55,17 +58,34 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||
int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
|
||||
int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
||||
public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getQueryText();
|
||||
Set<String> detectSegments = new HashSet<>();
|
||||
|
||||
int embeddingTextSize = Integer.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE));
|
||||
|
||||
int embeddingTextStep = Integer.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP));
|
||||
|
||||
for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) {
|
||||
int endIndex = Math.min(startIndex + embeddingTextSize, text.length());
|
||||
String detectSegment = text.substring(startIndex, endIndex).trim();
|
||||
detectSegments.add(detectSegment);
|
||||
}
|
||||
Set<EmbeddingResult> results = detectByBatch(chatQueryContext, detectDataSetIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected Set<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||
Set<EmbeddingResult> results = new HashSet<>();
|
||||
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||
&& detectSegment.length() >= embeddingMapperMin
|
||||
&& detectSegment.length() <= embeddingMapperMax)
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||
@@ -74,6 +94,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
|
||||
@@ -49,16 +49,16 @@ public class MapperConfig extends ParameterConfig {
|
||||
"维度值相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_MIN =
|
||||
new Parameter("s2.mapper.embedding.word.min", "4",
|
||||
"用于向量召回最小的文本长度",
|
||||
"为提高向量召回效率, 小于该长度的文本不进行向量语义召回",
|
||||
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
|
||||
new Parameter("s2.mapper.embedding.word.size", "4",
|
||||
"用于向量召回文本长度",
|
||||
"为提高向量召回效率, 按指定长度进行向量语义召回",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_MAX =
|
||||
new Parameter("s2.mapper.embedding.word.max", "5",
|
||||
"用于向量召回最大的文本长度",
|
||||
"为提高向量召回效率, 大于该长度的文本不进行向量语义召回",
|
||||
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
|
||||
new Parameter("s2.mapper.embedding.word.step", "3",
|
||||
"向量召回文本每步长度",
|
||||
"为提高向量召回效率, 按指定每步长度进行召回",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_BATCH =
|
||||
|
||||
Reference in New Issue
Block a user