[improvement][chat] Use a generic thread pool to perform concurrent mapping. (#1965)

This commit is contained in:
lexluo09
2024-12-21 11:58:02 +08:00
committed by GitHub
parent c2d155705f
commit f7fce0217f
2 changed files with 63 additions and 10 deletions

View File

@@ -0,0 +1,29 @@
package com.tencent.supersonic.common.config;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.stereotype.Component;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@Component
public class ThreadPoolConfig {
@Bean("commonExecutor")
public ThreadPoolExecutor getCommonExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(1024),
new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
@Bean("mapExecutor")
public ThreadPoolExecutor getMapExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(8192),
new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.chat.mapper;
import com.tencent.supersonic.common.config.ThreadPoolConfig;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
@@ -8,10 +9,11 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
@Service
@Slf4j
@@ -20,33 +22,55 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
protected MapperConfig mapperConfig;
@Autowired
protected MapperHelper mapperHelper;
@Autowired
protected ThreadPoolConfig threadPoolConfig;
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
String text = chatQueryContext.getRequest().getQueryText();
Set<T> results = new HashSet<>();
Set<T> results = ConcurrentHashMap.newKeySet();
Set<String> detectSegments = ConcurrentHashMap.newKeySet();
List<Callable<Void>> tasks = new ArrayList<>();
Set<String> detectSegments = new HashSet<>();
for (Integer startIndex = 0; startIndex <= text.length() - 1;) {
for (Integer index = startIndex; index <= text.length();) {
for (int startIndex = 0; startIndex <= text.length() - 1;) {
for (int index = startIndex; index <= text.length();) {
int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
List<T> oneRoundResults =
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
selectResultInOneRound(results, oneRoundResults);
tasks.add(createTask(chatQueryContext, detectDataSetIds, detectSegment, offset,
results));
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
executeTasks(tasks);
return new ArrayList<>(results);
}
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
String detectSegment, int offset, Set<T> results) {
return () -> {
List<T> oneRoundResults =
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
synchronized (results) {
selectResultInOneRound(results, oneRoundResults);
}
return null;
};
}
private void executeTasks(List<Callable<Void>> tasks) {
try {
threadPoolConfig.getMapExecutor().invokeAll(tasks);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Task execution interrupted", e);
}
}
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, String detectSegment, int offset);
}