diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ThreadPoolConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ThreadPoolConfig.java new file mode 100644 index 000000000..8c8cbb20a --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/config/ThreadPoolConfig.java @@ -0,0 +1,29 @@ +package com.tencent.supersonic.common.config; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.springframework.context.annotation.Bean; +import org.springframework.stereotype.Component; + +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +@Component +public class ThreadPoolConfig { + + @Bean("commonExecutor") + public ThreadPoolExecutor getCommonExecutor() { + return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, + new LinkedBlockingQueue(1024), + new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(), + new ThreadPoolExecutor.CallerRunsPolicy()); + } + + @Bean("mapExecutor") + public ThreadPoolExecutor getMapExecutor() { + return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, + new LinkedBlockingQueue(8192), + new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(), + new ThreadPoolExecutor.CallerRunsPolicy()); + } +} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java index 0de3df0a1..770a7f097 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.chat.mapper; +import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.MapResult; @@ -8,10 +9,11 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; @Service @Slf4j @@ -20,33 +22,55 @@ public abstract class SingleMatchStrategy extends BaseMatch protected MapperConfig mapperConfig; @Autowired protected MapperHelper mapperHelper; + @Autowired + protected ThreadPoolConfig threadPoolConfig; public List detect(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { Map regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); String text = chatQueryContext.getRequest().getQueryText(); - Set results = new HashSet<>(); + Set results = ConcurrentHashMap.newKeySet(); + Set detectSegments = ConcurrentHashMap.newKeySet(); + List> tasks = new ArrayList<>(); - Set detectSegments = new HashSet<>(); - - for (Integer startIndex = 0; startIndex <= text.length() - 1;) { - - for (Integer index = startIndex; index <= text.length();) { + for (int startIndex = 0; startIndex <= text.length() - 1;) { + for (int index = startIndex; index <= text.length();) { int offset = mapperHelper.getStepOffset(terms, startIndex); index = mapperHelper.getStepIndex(regOffsetToLength, index); if (index <= text.length()) { String detectSegment = text.substring(startIndex, index).trim(); detectSegments.add(detectSegment); - List oneRoundResults = - detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset); - selectResultInOneRound(results, oneRoundResults); + tasks.add(createTask(chatQueryContext, detectDataSetIds, detectSegment, offset, + results)); } } startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); } + executeTasks(tasks); return new ArrayList<>(results); } + private Callable createTask(ChatQueryContext chatQueryContext, Set detectDataSetIds, + String detectSegment, int offset, Set results) { + return () -> { + List oneRoundResults = + detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset); + synchronized (results) { + selectResultInOneRound(results, oneRoundResults); + } + return null; + }; + } + + private void executeTasks(List> tasks) { + try { + threadPoolConfig.getMapExecutor().invokeAll(tasks); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Task execution interrupted", e); + } + } + public abstract List detectByStep(ChatQueryContext chatQueryContext, Set detectDataSetIds, String detectSegment, int offset); }