diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index 38ac2908e..1b49507f9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -69,14 +69,9 @@ public abstract class BaseMatchStrategy implements MatchStrategy { } startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); } - detectByBatch(chatQueryContext, results, detectDataSetIds, detectSegments); return new ArrayList<>(results); } - protected void detectByBatch(ChatQueryContext chatQueryContext, Set results, Set detectDataSetIds, - Set detectSegments) { - } - public Map getRegOffsetToLength(List terms) { return terms.stream().sorted(Comparator.comparing(S2Term::length)) .collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(), diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index b3642d85e..d9b492d66 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.headless.api.pojo.response.S2Term; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; @@ -15,7 +16,9 @@ import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import java.util.ArrayList; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -55,17 +58,34 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { } @Override - protected void detectByBatch(ChatQueryContext chatQueryContext, Set results, - Set detectDataSetIds, Set detectSegments) { - int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN)); - int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX)); + public List detect(ChatQueryContext chatQueryContext, List terms, + Set detectDataSetIds) { + String text = chatQueryContext.getQueryText(); + Set detectSegments = new HashSet<>(); + + int embeddingTextSize = Integer.valueOf( + mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE)); + + int embeddingTextStep = Integer.valueOf( + mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP)); + + for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) { + int endIndex = Math.min(startIndex + embeddingTextSize, text.length()); + String detectSegment = text.substring(startIndex, endIndex).trim(); + detectSegments.add(detectSegment); + } + Set results = detectByBatch(chatQueryContext, detectDataSetIds, detectSegments); + return new ArrayList<>(results); + } + + protected Set detectByBatch(ChatQueryContext chatQueryContext, + Set detectDataSetIds, Set detectSegments) { + Set results = new HashSet<>(); int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); List queryTextsList = detectSegments.stream() .map(detectSegment -> detectSegment.trim()) - .filter(detectSegment -> StringUtils.isNotBlank(detectSegment) - && detectSegment.length() >= embeddingMapperMin - && detectSegment.length() <= embeddingMapperMax) + .filter(detectSegment -> StringUtils.isNotBlank(detectSegment)) .collect(Collectors.toList()); List> queryTextsSubList = Lists.partition(queryTextsList, @@ -74,6 +94,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { for (List queryTextsSub : queryTextsSubList) { detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext); } + return results; } private void detectByQueryTextsSub(Set results, Set detectDataSetIds, diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java index 7ed39deac..f5a7af8af 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java @@ -49,16 +49,16 @@ public class MapperConfig extends ParameterConfig { "维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"); - public static final Parameter EMBEDDING_MAPPER_MIN = - new Parameter("s2.mapper.embedding.word.min", "4", - "用于向量召回最小的文本长度", - "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", + public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE = + new Parameter("s2.mapper.embedding.word.size", "4", + "用于向量召回文本长度", + "为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置"); - public static final Parameter EMBEDDING_MAPPER_MAX = - new Parameter("s2.mapper.embedding.word.max", "5", - "用于向量召回最大的文本长度", - "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", + public static final Parameter EMBEDDING_MAPPER_TEXT_STEP = + new Parameter("s2.mapper.embedding.word.step", "3", + "向量召回文本每步长度", + "为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_BATCH = diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index b7a53ca21..ac79deb40 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -60,7 +60,6 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay)); @@ -80,7 +79,6 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); @@ -101,7 +99,6 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); List list = new ArrayList<>(); @@ -128,9 +125,9 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE); expectedParseInfo.setAggType(SUM); + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户")); expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY")); @@ -149,7 +146,6 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE); expectedParseInfo.setAggType(SUM); - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); @@ -175,7 +171,6 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",