diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ThreadPoolConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ThreadPoolConfig.java index 83d04ef44..2971a4f85 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ThreadPoolConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ThreadPoolConfig.java @@ -14,23 +14,25 @@ public class ThreadPoolConfig { @Bean("commonExecutor") public ThreadPoolExecutor getCommonExecutor() { return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, - new LinkedBlockingQueue(1024), + new LinkedBlockingQueue<>(1024), new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(), new ThreadPoolExecutor.CallerRunsPolicy()); } @Bean("mapExecutor") public ThreadPoolExecutor getMapExecutor() { - return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, - new LinkedBlockingQueue(8192), + return new ThreadPoolExecutor( + 8, 16, 60 * 3, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(), new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(), - new ThreadPoolExecutor.CallerRunsPolicy()); + new ThreadPoolExecutor.CallerRunsPolicy() + ); } @Bean("chatExecutor") public ThreadPoolExecutor getChatExecutor() { return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, - new LinkedBlockingQueue(1024), + new LinkedBlockingQueue<>(1024), new ThreadFactoryBuilder().setNameFormat("supersonic-chat-pool-").build(), new ThreadPoolExecutor.CallerRunsPolicy()); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index a99879e3c..8c4a35741 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.chat.mapper; +import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; @@ -7,6 +8,7 @@ import com.tencent.supersonic.headless.chat.knowledge.MapResult; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.HashMap; @@ -14,10 +16,15 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.concurrent.Callable; @Service @Slf4j public abstract class BaseMatchStrategy implements MatchStrategy { + + @Autowired + protected ThreadPoolConfig threadPoolConfig; + @Override public Map> match(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { @@ -63,6 +70,14 @@ public abstract class BaseMatchStrategy implements MatchStr } } + protected void executeTasks(List> tasks) { + try { + threadPoolConfig.getMapExecutor().invokeAll(tasks); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Task execution interrupted", e); + } + } public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { if (MapModeEnum.STRICT.equals(mapModeEnum)) { return 1.0d; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index 1b73c5cc8..861655dac 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.mapper; import com.google.common.collect.Lists; +import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; @@ -16,10 +17,11 @@ import org.springframework.stereotype.Service; import java.util.ArrayList; import java.util.Comparator; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER; @@ -36,11 +38,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy @Autowired private MetaEmbeddingService metaEmbeddingService; + @Autowired + protected ThreadPoolConfig threadPoolConfig; @Override public List detectByBatch(ChatQueryContext chatQueryContext, - Set detectDataSetIds, Set detectSegments) { - Set results = new HashSet<>(); + Set detectDataSetIds, Set detectSegments) { + Set results = ConcurrentHashMap.newKeySet(); int embeddingMapperBatch = Integer .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); @@ -52,16 +56,28 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy List> queryTextsSubList = Lists.partition(queryTextsList, embeddingMapperBatch); + List> tasks = new ArrayList<>(); for (List queryTextsSub : queryTextsSubList) { - List oneRoundResults = - detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext); - selectResultInOneRound(results, oneRoundResults); + tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results)); } + executeTasks(tasks); return new ArrayList<>(results); } + private Callable createTask(ChatQueryContext chatQueryContext, Set detectDataSetIds, + List queryTextsSub, Set results) { + return () -> { + List oneRoundResults = + detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext); + synchronized (results) { + selectResultInOneRound(results, oneRoundResults); + } + return null; + }; + } + private List detectByQueryTextsSub(Set detectDataSetIds, - List queryTextsSub, ChatQueryContext chatQueryContext) { + List queryTextsSub, ChatQueryContext chatQueryContext) { Map> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds(); double threshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD)); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java index 770a7f097..4ca18e142 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.chat.mapper; -import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.MapResult; @@ -22,26 +21,22 @@ public abstract class SingleMatchStrategy extends BaseMatch protected MapperConfig mapperConfig; @Autowired protected MapperHelper mapperHelper; - @Autowired - protected ThreadPoolConfig threadPoolConfig; public List detect(ChatQueryContext chatQueryContext, List terms, - Set detectDataSetIds) { + Set detectDataSetIds) { Map regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); String text = chatQueryContext.getRequest().getQueryText(); Set results = ConcurrentHashMap.newKeySet(); - Set detectSegments = ConcurrentHashMap.newKeySet(); List> tasks = new ArrayList<>(); - for (int startIndex = 0; startIndex <= text.length() - 1;) { - for (int index = startIndex; index <= text.length();) { + for (int startIndex = 0; startIndex <= text.length() - 1; ) { + for (int index = startIndex; index <= text.length(); ) { int offset = mapperHelper.getStepOffset(terms, startIndex); index = mapperHelper.getStepIndex(regOffsetToLength, index); if (index <= text.length()) { String detectSegment = text.substring(startIndex, index).trim(); - detectSegments.add(detectSegment); - tasks.add(createTask(chatQueryContext, detectDataSetIds, detectSegment, offset, - results)); + Callable task = createTask(chatQueryContext, detectDataSetIds, detectSegment, offset, results); + tasks.add(task); } } startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); @@ -51,7 +46,7 @@ public abstract class SingleMatchStrategy extends BaseMatch } private Callable createTask(ChatQueryContext chatQueryContext, Set detectDataSetIds, - String detectSegment, int offset, Set results) { + String detectSegment, int offset, Set results) { return () -> { List oneRoundResults = detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset); @@ -62,15 +57,6 @@ public abstract class SingleMatchStrategy extends BaseMatch }; } - private void executeTasks(List> tasks) { - try { - threadPoolConfig.getMapExecutor().invokeAll(tasks); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Task execution interrupted", e); - } - } - public abstract List detectByStep(ChatQueryContext chatQueryContext, - Set detectDataSetIds, String detectSegment, int offset); + Set detectDataSetIds, String detectSegment, int offset); }