[improvement][chat] Change the embedding to execute in parallel (#1967)

This commit is contained in:
lexluo09
2024-12-21 20:32:03 +08:00
committed by GitHub
parent 7dc013dfb3
commit 8c6ae62522
4 changed files with 52 additions and 33 deletions

View File

@@ -14,23 +14,25 @@ public class ThreadPoolConfig {
@Bean("commonExecutor") @Bean("commonExecutor")
public ThreadPoolExecutor getCommonExecutor() { public ThreadPoolExecutor getCommonExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(1024), new LinkedBlockingQueue<>(1024),
new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(), new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy()); new ThreadPoolExecutor.CallerRunsPolicy());
} }
@Bean("mapExecutor") @Bean("mapExecutor")
public ThreadPoolExecutor getMapExecutor() { public ThreadPoolExecutor getMapExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, return new ThreadPoolExecutor(
new LinkedBlockingQueue<Runnable>(8192), 8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(),
new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(), new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy()); new ThreadPoolExecutor.CallerRunsPolicy()
);
} }
@Bean("chatExecutor") @Bean("chatExecutor")
public ThreadPoolExecutor getChatExecutor() { public ThreadPoolExecutor getChatExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(1024), new LinkedBlockingQueue<>(1024),
new ThreadFactoryBuilder().setNameFormat("supersonic-chat-pool-").build(), new ThreadFactoryBuilder().setNameFormat("supersonic-chat-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy()); new ThreadPoolExecutor.CallerRunsPolicy());
} }

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.chat.mapper; package com.tencent.supersonic.headless.chat.mapper;
import com.tencent.supersonic.common.config.ThreadPoolConfig;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
@@ -7,6 +8,7 @@ import com.tencent.supersonic.headless.chat.knowledge.MapResult;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.HashMap; import java.util.HashMap;
@@ -14,10 +16,15 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Callable;
@Service @Service
@Slf4j @Slf4j
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> { public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
@Autowired
protected ThreadPoolConfig threadPoolConfig;
@Override @Override
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms, public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
@@ -63,6 +70,14 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
} }
} }
/**
 * Runs every task on the map executor and waits for all of them to finish.
 *
 * <p>{@code invokeAll} only returns once each task has completed, but it does not
 * surface exceptions thrown inside a task: they stay captured in the returned
 * futures. Each future is therefore inspected so a failed task is logged instead
 * of disappearing silently. A failure in one task does not abort the others and
 * is not rethrown, so callers still receive whatever the successful tasks produced.
 *
 * @param tasks the units of work to run in parallel
 * @throws RuntimeException if the calling thread is interrupted while waiting; the
 *         interrupt flag is restored before the exception is thrown
 */
protected void executeTasks(List<Callable<Void>> tasks) {
    try {
        List<java.util.concurrent.Future<Void>> futures =
                threadPoolConfig.getMapExecutor().invokeAll(tasks);
        for (java.util.concurrent.Future<Void> future : futures) {
            try {
                // Completed futures return immediately; get() is used only to surface failures.
                future.get();
            } catch (java.util.concurrent.ExecutionException e) {
                log.error("Task execution failed", e);
            }
        }
    } catch (InterruptedException e) {
        // Restore the flag so code further up the stack can still observe the interrupt.
        Thread.currentThread().interrupt();
        throw new RuntimeException("Task execution interrupted", e);
    }
}
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
if (MapModeEnum.STRICT.equals(mapModeEnum)) { if (MapModeEnum.STRICT.equals(mapModeEnum)) {
return 1.0d; return 1.0d;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.mapper; package com.tencent.supersonic.headless.chat.mapper;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.ThreadPoolConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
@@ -16,10 +17,11 @@ import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER; import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
@@ -36,11 +38,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
@Autowired @Autowired
private MetaEmbeddingService metaEmbeddingService; private MetaEmbeddingService metaEmbeddingService;
@Autowired
protected ThreadPoolConfig threadPoolConfig;
@Override @Override
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext, public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments) { Set<Long> detectDataSetIds, Set<String> detectSegments) {
Set<EmbeddingResult> results = new HashSet<>(); Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
int embeddingMapperBatch = Integer int embeddingMapperBatch = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
@@ -52,16 +56,28 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
List<List<String>> queryTextsSubList = List<List<String>> queryTextsSubList =
Lists.partition(queryTextsList, embeddingMapperBatch); Lists.partition(queryTextsList, embeddingMapperBatch);
List<Callable<Void>> tasks = new ArrayList<>();
for (List<String> queryTextsSub : queryTextsSubList) { for (List<String> queryTextsSub : queryTextsSubList) {
List<EmbeddingResult> oneRoundResults = tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results));
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
selectResultInOneRound(results, oneRoundResults);
} }
executeTasks(tasks);
return new ArrayList<>(results); return new ArrayList<>(results);
} }
/**
 * Builds the unit of work for one batch of query texts.
 *
 * <p>When run, the task calls {@code detectByQueryTextsSub} for the batch and then
 * merges the matches into {@code results} through {@code selectResultInOneRound}.
 * The merge is done while holding the monitor of {@code results}, so merges from
 * different batches never overlap.
 * NOTE(review): presumably {@code selectResultInOneRound} performs compound
 * operations on the set and so needs this lock even though the set is concurrent;
 * confirm.
 *
 * @param chatQueryContext the current query context, passed through to detection
 * @param detectDataSetIds the data set ids to detect against
 * @param queryTextsSub the batch of query texts handled by this task
 * @param results the shared set that collects matches from all batches
 * @return a callable that performs the detection and merge and yields {@code null}
 */
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
        List<String> queryTextsSub, Set<EmbeddingResult> results) {
    Callable<Void> batchTask = () -> {
        // Detection runs outside the lock so batches can proceed in parallel.
        List<EmbeddingResult> batchMatches =
                detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
        synchronized (results) {
            selectResultInOneRound(results, batchMatches);
        }
        return null;
    };
    return batchTask;
}
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds, private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
List<String> queryTextsSub, ChatQueryContext chatQueryContext) { List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds(); Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
double threshold = double threshold =
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD)); Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.mapper; package com.tencent.supersonic.headless.chat.mapper;
import com.tencent.supersonic.common.config.ThreadPoolConfig;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.MapResult; import com.tencent.supersonic.headless.chat.knowledge.MapResult;
@@ -22,26 +21,22 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
protected MapperConfig mapperConfig; protected MapperConfig mapperConfig;
@Autowired @Autowired
protected MapperHelper mapperHelper; protected MapperHelper mapperHelper;
@Autowired
protected ThreadPoolConfig threadPoolConfig;
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms, public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
String text = chatQueryContext.getRequest().getQueryText(); String text = chatQueryContext.getRequest().getQueryText();
Set<T> results = ConcurrentHashMap.newKeySet(); Set<T> results = ConcurrentHashMap.newKeySet();
Set<String> detectSegments = ConcurrentHashMap.newKeySet();
List<Callable<Void>> tasks = new ArrayList<>(); List<Callable<Void>> tasks = new ArrayList<>();
for (int startIndex = 0; startIndex <= text.length() - 1;) { for (int startIndex = 0; startIndex <= text.length() - 1; ) {
for (int index = startIndex; index <= text.length();) { for (int index = startIndex; index <= text.length(); ) {
int offset = mapperHelper.getStepOffset(terms, startIndex); int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index); index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) { if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim(); String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment); Callable<Void> task = createTask(chatQueryContext, detectDataSetIds, detectSegment, offset, results);
tasks.add(createTask(chatQueryContext, detectDataSetIds, detectSegment, offset, tasks.add(task);
results));
} }
} }
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
@@ -51,7 +46,7 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
} }
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds, private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
String detectSegment, int offset, Set<T> results) { String detectSegment, int offset, Set<T> results) {
return () -> { return () -> {
List<T> oneRoundResults = List<T> oneRoundResults =
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset); detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
@@ -62,15 +57,6 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
}; };
} }
/**
 * Hands all tasks to the map executor and blocks until they have completed.
 *
 * <p>The futures returned by {@code invokeAll} are not examined, so the outcome
 * of individual tasks is not reported to the caller.
 *
 * @param tasks the units of work to run
 * @throws RuntimeException if the calling thread is interrupted while waiting; the
 *         interrupt flag is set again before the exception is thrown
 */
private void executeTasks(List<Callable<Void>> tasks) {
    java.util.concurrent.ExecutorService mapPool = threadPoolConfig.getMapExecutor();
    try {
        mapPool.invokeAll(tasks);
    } catch (InterruptedException interrupted) {
        // Re-assert the interrupt before translating it to an unchecked exception.
        Thread.currentThread().interrupt();
        throw new RuntimeException("Task execution interrupted", interrupted);
    }
}
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext, public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, String detectSegment, int offset); Set<Long> detectDataSetIds, String detectSegment, int offset);
} }