diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java index ab37a01ad..772870053 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java @@ -42,18 +42,21 @@ public class OptimizationConfig { @Value("${user.s2ql.switch:false}") private boolean useS2qlSwitch; - @Value("${embedding.mapper.word.min:2}") + @Value("${embedding.mapper.word.min:3}") private int embeddingMapperWordMin; - @Value("${embedding.mapper.number:10}") + @Value("${embedding.mapper.word.max:5}") + private int embeddingMapperWordMax; + + @Value("${embedding.mapper.batch:50}") + private int embeddingMapperBatch; + + @Value("${embedding.mapper.number:5}") private int embeddingMapperNumber; @Value("${embedding.mapper.round.number:10}") private int embeddingMapperRoundNumber; - @Value("${embedding.mapper.sum.number:10}") - private int embeddingMapperSumNumber; - - @Value("${embedding.mapper.distance.threshold:0.3}") + @Value("${embedding.mapper.distance.threshold:0.52}") private Double embeddingMapperDistanceThreshold; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java index 693c29846..6be372cf4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMatchStrategy.java @@ -50,20 +50,30 @@ public abstract class BaseMatchStrategy implements MatchStrategy { String text = queryContext.getRequest().getQueryText(); Set results = new HashSet<>(); - for (Integer index = 0; index <= text.length() - 1; ) { + Set detectSegments = new HashSet<>(); - for (Integer i = index; i <= text.length(); ) { - int offset = mapperHelper.getStepOffset(terms, index); - i = mapperHelper.getStepIndex(regOffsetToLength, i); - if (i <= text.length()) { - detectByStep(queryContext, results, detectModelIds, index, i, offset); + for (Integer startIndex = 0; startIndex <= text.length() - 1; ) { + + for (Integer index = startIndex; index <= text.length(); ) { + int offset = mapperHelper.getStepOffset(terms, startIndex); + index = mapperHelper.getStepIndex(regOffsetToLength, index); + if (index <= text.length()) { + String detectSegment = text.substring(startIndex, index); + detectSegments.add(detectSegment); + detectByStep(queryContext, results, detectModelIds, startIndex, index, offset); } } - index = mapperHelper.getStepIndex(regOffsetToLength, index); + startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); } + detectByBatch(queryContext, results, detectModelIds, detectSegments); return new ArrayList<>(results); } + protected void detectByBatch(QueryContext queryContext, Set results, Set detectModelIds, + Set detectSegments) { + return; + } + public Map getRegOffsetToLength(List terms) { return terms.stream().sorted(Comparator.comparing(Term::length)) .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2)); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java index 5e2df84f0..130df3a2e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.chat.mapper; +import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; @@ -10,7 +10,6 @@ import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult; import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener; -import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -47,19 +46,32 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { return a.getName() + Constants.UNDERLINE + a.getId(); } - public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, - Integer startIndex, Integer index, int offset) { - QueryReq queryReq = queryContext.getRequest(); - String detectSegment = queryReq.getQueryText().substring(startIndex, index); - // step1. build query params - if (StringUtils.isBlank(detectSegment) - || detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMin()) { - return; + + @Override + protected void detectByBatch(QueryContext queryContext, Set results, Set detectModelIds, + Set detectSegments) { + + List queryTextsList = detectSegments.stream() + .map(detectSegment -> detectSegment.trim()) + .filter(detectSegment -> StringUtils.isNotBlank(detectSegment) + && detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin() + && detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax()) + .collect(Collectors.toList()); + + List> queryTextsSubList = Lists.partition(queryTextsList, + optimizationConfig.getEmbeddingMapperBatch()); + + for (List queryTextsSub : queryTextsSubList) { + detectByQueryTextsSub(results, detectModelIds, queryTextsSub); } + } + + private void detectByQueryTextsSub(Set results, Set detectModelIds, + List queryTextsSub) { int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber(); Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold(); Map filterCondition = null; - + // step1. build query params // if only one modelId, add to filterCondition if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) { filterCondition = new HashMap<>(); @@ -67,7 +79,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { } RetrieveQuery retrieveQuery = RetrieveQuery.builder() - .queryTextsList(Collections.singletonList(detectSegment)) + .queryTextsList(queryTextsSub) .filterCondition(filterCondition) .queryEmbeddings(null) .build(); @@ -101,18 +113,24 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { .map(retrieval -> { EmbeddingResult embeddingResult = new EmbeddingResult(); BeanUtils.copyProperties(retrieval, embeddingResult); - embeddingResult.setDetectWord(detectSegment); + embeddingResult.setDetectWord(retrieveQueryResult.getQuery()); embeddingResult.setName(retrieval.getQuery()); return embeddingResult; })) .collect(Collectors.toList()); // step4. select mapResul in one round - int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber(); + int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size(); List oneRoundResults = collect.stream() .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance)) .limit(roundNumber) .collect(Collectors.toList()); - selectResultInOneRound(existResults, oneRoundResults); + selectResultInOneRound(results, oneRoundResults); + } + + @Override + public void detectByStep(QueryContext queryContext, Set existResults, Set detectModelIds, + Integer startIndex, Integer index, int offset) { + return; } }