diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index 6c02674f5..025acd5e5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -12,12 +12,17 @@ import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ThreadPoolExecutor; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; @Service @Slf4j @@ -72,18 +77,39 @@ public abstract class BaseMatchStrategy implements MatchStr } } - protected void executeTasks(List> tasks) { + protected Set executeTasks(List>> tasks) { + + Function>, Supplier>> decorator = taskDecorator(); + List>> futures; + if (decorator == null) { + futures = tasks.stream().map(t -> CompletableFuture.supplyAsync(t, executor)).toList(); + } else { + futures = tasks.stream() + .map(t -> CompletableFuture.supplyAsync(decorator.apply(t), executor)).toList(); + } + + CompletableFuture> listCompletableFuture = + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply(v -> futures.stream() + .flatMap(listFuture -> listFuture.join().stream()) + .collect(Collectors.toList())); try { - executor.invokeAll(tasks); - for (Callable future : tasks) { - future.call(); - } - } catch (Exception e) { + List ts = listCompletableFuture.get(); + Set results = new HashSet<>(); + selectResultInOneRound(results, ts); + return results; + } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException("Task execution interrupted", e); + } catch (ExecutionException e) { + throw new RuntimeException(e); } } + public Function>, Supplier>> taskDecorator() { + return null; + } + public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { if (MapModeEnum.STRICT.equals(mapModeEnum)) { return 1.0d; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java index 5cc2d2fba..5c53a898d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java @@ -17,6 +17,8 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -76,6 +78,22 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy>, Supplier>> taskDecorator() { + List schemaElements = allElements.get(); + if (CollectionUtils.isEmpty(schemaElements)) { + return null; + } + return (t) -> (Supplier>) () -> { + try { + allElements.set(schemaElements); + return t.get(); + } finally { + allElements.remove(); + } + }; + } + private Double getThreshold(ChatQueryContext chatQueryContext) { Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD)); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index 009d5b166..45593c1df 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -24,8 +24,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.*; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; import java.util.stream.Collectors; import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*; @@ -141,7 +140,6 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy */ public List detectByBatch(ChatQueryContext chatQueryContext, Set detectDataSetIds, Set detectSegments, boolean useLlm) { - Set results = ConcurrentHashMap.newKeySet(); int embeddingMapperBatch = Integer .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); @@ -154,12 +152,11 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy Lists.partition(queryTextsList, embeddingMapperBatch); // Create and execute tasks for each batch - List> tasks = new ArrayList<>(); + List>> tasks = new ArrayList<>(); for (List queryTextsSub : queryTextsSubList) { - tasks.add( - createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm)); + tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, useLlm)); } - executeTasks(tasks); + Set results = executeTasks(tasks); // Apply LLM filtering if enabled if (useLlm) { @@ -196,20 +193,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy * @param chatQueryContext The context of the chat query * @param detectDataSetIds Target dataset IDs * @param queryTextsSub Sub-list of query texts to process - * @param results Shared result set for collecting results * @param useLlm Whether to use LLM - * @return Callable task + * @return Supplier task */ - private Callable createTask(ChatQueryContext chatQueryContext, Set detectDataSetIds, - List queryTextsSub, Set results, boolean useLlm) { - return () -> { - List oneRoundResults = detectByQueryTextsSub(detectDataSetIds, - queryTextsSub, chatQueryContext, useLlm); - synchronized (results) { - selectResultInOneRound(results, oneRoundResults); - } - return null; - }; + private Supplier> createTask(ChatQueryContext chatQueryContext, + Set detectDataSetIds, List queryTextsSub, boolean useLlm) { + return () -> detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext, + useLlm); } /** diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java index 508ca8cef..1aa00365f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java @@ -11,8 +11,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; @Service @Slf4j @@ -26,8 +25,7 @@ public abstract class SingleMatchStrategy extends BaseMatch Set detectDataSetIds) { Map regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); String text = chatQueryContext.getRequest().getQueryText(); - Set results = ConcurrentHashMap.newKeySet(); - List> tasks = new ArrayList<>(); + List>> tasks = new ArrayList<>(); for (int startIndex = 0; startIndex <= text.length() - 1;) { for (int index = startIndex; index <= text.length();) { @@ -35,27 +33,20 @@ public abstract class SingleMatchStrategy extends BaseMatch index = mapperHelper.getStepIndex(regOffsetToLength, index); if (index <= text.length()) { String detectSegment = text.substring(startIndex, index).trim(); - Callable task = createTask(chatQueryContext, detectDataSetIds, - detectSegment, offset, results); + Supplier> task = + createTask(chatQueryContext, detectDataSetIds, detectSegment, offset); tasks.add(task); } } startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); } - executeTasks(tasks); + Set results = executeTasks(tasks); return new ArrayList<>(results); } - private Callable createTask(ChatQueryContext chatQueryContext, Set detectDataSetIds, - String detectSegment, int offset, Set results) { - return () -> { - List oneRoundResults = - detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset); - synchronized (results) { - selectResultInOneRound(results, oneRoundResults); - } - return null; - }; + private Supplier> createTask(ChatQueryContext chatQueryContext, + Set detectDataSetIds, String detectSegment, int offset) { + return () -> detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset); } public abstract List detectByStep(ChatQueryContext chatQueryContext,