mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement][chat] Change the embedding to execute in parallel (#1967)
This commit is contained in:
@@ -14,23 +14,25 @@ public class ThreadPoolConfig {
|
|||||||
@Bean("commonExecutor")
|
@Bean("commonExecutor")
|
||||||
public ThreadPoolExecutor getCommonExecutor() {
|
public ThreadPoolExecutor getCommonExecutor() {
|
||||||
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
|
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
|
||||||
new LinkedBlockingQueue<Runnable>(1024),
|
new LinkedBlockingQueue<>(1024),
|
||||||
new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(),
|
new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(),
|
||||||
new ThreadPoolExecutor.CallerRunsPolicy());
|
new ThreadPoolExecutor.CallerRunsPolicy());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean("mapExecutor")
|
@Bean("mapExecutor")
|
||||||
public ThreadPoolExecutor getMapExecutor() {
|
public ThreadPoolExecutor getMapExecutor() {
|
||||||
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
|
return new ThreadPoolExecutor(
|
||||||
new LinkedBlockingQueue<Runnable>(8192),
|
8, 16, 60 * 3, TimeUnit.SECONDS,
|
||||||
|
new LinkedBlockingQueue<>(),
|
||||||
new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(),
|
new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(),
|
||||||
new ThreadPoolExecutor.CallerRunsPolicy());
|
new ThreadPoolExecutor.CallerRunsPolicy()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean("chatExecutor")
|
@Bean("chatExecutor")
|
||||||
public ThreadPoolExecutor getChatExecutor() {
|
public ThreadPoolExecutor getChatExecutor() {
|
||||||
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
|
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
|
||||||
new LinkedBlockingQueue<Runnable>(1024),
|
new LinkedBlockingQueue<>(1024),
|
||||||
new ThreadFactoryBuilder().setNameFormat("supersonic-chat-pool-").build(),
|
new ThreadFactoryBuilder().setNameFormat("supersonic-chat-pool-").build(),
|
||||||
new ThreadPoolExecutor.CallerRunsPolicy());
|
new ThreadPoolExecutor.CallerRunsPolicy());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
@@ -7,6 +8,7 @@ import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -14,10 +16,15 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
|
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
protected ThreadPoolConfig threadPoolConfig;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectDataSetIds) {
|
||||||
@@ -63,6 +70,14 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void executeTasks(List<Callable<Void>> tasks) {
|
||||||
|
try {
|
||||||
|
threadPoolConfig.getMapExecutor().invokeAll(tasks);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
throw new RuntimeException("Task execution interrupted", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||||
if (MapModeEnum.STRICT.equals(mapModeEnum)) {
|
if (MapModeEnum.STRICT.equals(mapModeEnum)) {
|
||||||
return 1.0d;
|
return 1.0d;
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||||
@@ -16,10 +17,11 @@ import org.springframework.stereotype.Service;
|
|||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
|
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
|
||||||
@@ -36,11 +38,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private MetaEmbeddingService metaEmbeddingService;
|
private MetaEmbeddingService metaEmbeddingService;
|
||||||
|
@Autowired
|
||||||
|
protected ThreadPoolConfig threadPoolConfig;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||||
Set<EmbeddingResult> results = new HashSet<>();
|
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
|
||||||
int embeddingMapperBatch = Integer
|
int embeddingMapperBatch = Integer
|
||||||
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||||
|
|
||||||
@@ -52,16 +56,28 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
List<List<String>> queryTextsSubList =
|
List<List<String>> queryTextsSubList =
|
||||||
Lists.partition(queryTextsList, embeddingMapperBatch);
|
Lists.partition(queryTextsList, embeddingMapperBatch);
|
||||||
|
|
||||||
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
List<EmbeddingResult> oneRoundResults =
|
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results));
|
||||||
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
|
|
||||||
selectResultInOneRound(results, oneRoundResults);
|
|
||||||
}
|
}
|
||||||
|
executeTasks(tasks);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
||||||
|
List<String> queryTextsSub, Set<EmbeddingResult> results) {
|
||||||
|
return () -> {
|
||||||
|
List<EmbeddingResult> oneRoundResults =
|
||||||
|
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
|
||||||
|
synchronized (results) {
|
||||||
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
|
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
|
||||||
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
|
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
|
||||||
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
|
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
|
||||||
double threshold =
|
double threshold =
|
||||||
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
||||||
@@ -22,26 +21,22 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
|||||||
protected MapperConfig mapperConfig;
|
protected MapperConfig mapperConfig;
|
||||||
@Autowired
|
@Autowired
|
||||||
protected MapperHelper mapperHelper;
|
protected MapperHelper mapperHelper;
|
||||||
@Autowired
|
|
||||||
protected ThreadPoolConfig threadPoolConfig;
|
|
||||||
|
|
||||||
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectDataSetIds) {
|
||||||
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
|
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
|
||||||
String text = chatQueryContext.getRequest().getQueryText();
|
String text = chatQueryContext.getRequest().getQueryText();
|
||||||
Set<T> results = ConcurrentHashMap.newKeySet();
|
Set<T> results = ConcurrentHashMap.newKeySet();
|
||||||
Set<String> detectSegments = ConcurrentHashMap.newKeySet();
|
|
||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
|
|
||||||
for (int startIndex = 0; startIndex <= text.length() - 1;) {
|
for (int startIndex = 0; startIndex <= text.length() - 1; ) {
|
||||||
for (int index = startIndex; index <= text.length();) {
|
for (int index = startIndex; index <= text.length(); ) {
|
||||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
if (index <= text.length()) {
|
if (index <= text.length()) {
|
||||||
String detectSegment = text.substring(startIndex, index).trim();
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
detectSegments.add(detectSegment);
|
Callable<Void> task = createTask(chatQueryContext, detectDataSetIds, detectSegment, offset, results);
|
||||||
tasks.add(createTask(chatQueryContext, detectDataSetIds, detectSegment, offset,
|
tasks.add(task);
|
||||||
results));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
@@ -51,7 +46,7 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
||||||
String detectSegment, int offset, Set<T> results) {
|
String detectSegment, int offset, Set<T> results) {
|
||||||
return () -> {
|
return () -> {
|
||||||
List<T> oneRoundResults =
|
List<T> oneRoundResults =
|
||||||
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
||||||
@@ -62,15 +57,6 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeTasks(List<Callable<Void>> tasks) {
|
|
||||||
try {
|
|
||||||
threadPoolConfig.getMapExecutor().invokeAll(tasks);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
Thread.currentThread().interrupt();
|
|
||||||
throw new RuntimeException("Task execution interrupted", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
|
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
|
||||||
Set<Long> detectDataSetIds, String detectSegment, int offset);
|
Set<Long> detectDataSetIds, String detectSegment, int offset);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user