(improvement)(chat) Improve vector recall performance. (#1458)

This commit is contained in:
lexluo09
2024-07-25 22:19:35 +08:00
committed by GitHub
parent c8df102402
commit ae34c15c95
4 changed files with 37 additions and 26 deletions

View File

@@ -69,14 +69,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
} }
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
} }
detectByBatch(chatQueryContext, results, detectDataSetIds, detectSegments);
return new ArrayList<>(results); return new ArrayList<>(results);
} }
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<T> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) {
}
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) { public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
return terms.stream().sorted(Comparator.comparing(S2Term::length)) return terms.stream().sorted(Comparator.comparing(S2Term::length))
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(), .collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
@@ -15,7 +16,9 @@ import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@@ -55,17 +58,34 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
} }
@Override @Override
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results, public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds, Set<String> detectSegments) { Set<Long> detectDataSetIds) {
int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN)); String text = chatQueryContext.getQueryText();
int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX)); Set<String> detectSegments = new HashSet<>();
int embeddingTextSize = Integer.valueOf(
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE));
int embeddingTextStep = Integer.valueOf(
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP));
for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) {
int endIndex = Math.min(startIndex + embeddingTextSize, text.length());
String detectSegment = text.substring(startIndex, endIndex).trim();
detectSegments.add(detectSegment);
}
Set<EmbeddingResult> results = detectByBatch(chatQueryContext, detectDataSetIds, detectSegments);
return new ArrayList<>(results);
}
protected Set<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments) {
Set<EmbeddingResult> results = new HashSet<>();
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
List<String> queryTextsList = detectSegments.stream() List<String> queryTextsList = detectSegments.stream()
.map(detectSegment -> detectSegment.trim()) .map(detectSegment -> detectSegment.trim())
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment) .filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
&& detectSegment.length() >= embeddingMapperMin
&& detectSegment.length() <= embeddingMapperMax)
.collect(Collectors.toList()); .collect(Collectors.toList());
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList, List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
@@ -74,6 +94,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
for (List<String> queryTextsSub : queryTextsSubList) { for (List<String> queryTextsSub : queryTextsSubList) {
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext); detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
} }
return results;
} }
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds, private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,

View File

@@ -49,16 +49,16 @@ public class MapperConfig extends ParameterConfig {
"维度值相似度阈值在动态调整中的最低值", "维度值相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置"); "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_MIN = public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
new Parameter("s2.mapper.embedding.word.min", "4", new Parameter("s2.mapper.embedding.word.size", "4",
"用于向量召回最小的文本长度", "用于向量召回文本长度",
"为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "为提高向量召回效率, 按指定长度进行向量语义召回",
"number", "Mapper相关配置"); "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_MAX = public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
new Parameter("s2.mapper.embedding.word.max", "5", new Parameter("s2.mapper.embedding.word.step", "3",
"用于向量召回最大的文本长度", "向量召回文本每步长度",
"为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "为提高向量召回效率, 按指定每步长度进行召回",
"number", "Mapper相关配置"); "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_BATCH = public static final Parameter EMBEDDING_MAPPER_BATCH =

View File

@@ -60,7 +60,6 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE); expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay)); expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
@@ -80,7 +79,6 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE); expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
@@ -101,7 +99,6 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
@@ -128,9 +125,9 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE); expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE);
expectedParseInfo.setAggType(SUM); expectedParseInfo.setAggType(SUM);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY")); expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY"));
@@ -149,7 +146,6 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE); expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE);
expectedParseInfo.setAggType(SUM); expectedParseInfo.setAggType(SUM);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
@@ -175,7 +171,6 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",