2 Commits

Author SHA1 Message Date
Xiong Tenghui
80aaabe58b (improvement) (headless) Optimize the performance of the method BaseMatchStrategy.executeTasks() (#2363) (#2364)
Some checks failed
supersonic CentOS CI / build (21) (push) Has been cancelled
supersonic mac CI / build (21) (push) Has been cancelled
supersonic ubuntu CI / build (21) (push) Has been cancelled
supersonic windows CI / build (21) (push) Has been cancelled
2025-08-20 18:52:55 +08:00
lwhy
5a4fd2b888 (feature|common)Add parameter conversion for the LikeExpression in FieldValueReplaceVisitor (#2367) 2025-08-20 18:50:43 +08:00
5 changed files with 124 additions and 43 deletions

View File

@@ -46,6 +46,62 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
replaceComparisonExpression(expr); replaceComparisonExpression(expr);
} }
/**
 * Rewrites the string literal on the right-hand side of a {@code LIKE} expression using the
 * replacement values registered for the column on the left-hand side.
 *
 * <p>The pattern is split on the {@code %} wildcard and every non-empty fragment is looked up
 * individually through {@code getReplaceValue}. When at least one fragment is replaced, the
 * fragments are re-joined with {@code %} so every wildcard keeps its original position. Only
 * {@code %} is treated as a separator; any other character (including {@code _}) stays part of
 * the fragment that is looked up.
 *
 * <p>The expression is left untouched unless all of the following hold: the left operand is a
 * {@code Column} with a non-empty name, {@code filedNameToValueMap} holds a non-empty value map
 * for that column name, and the right operand is a {@code StringValue} with a non-null value.
 *
 * @param expr the LIKE expression to inspect; its right-hand {@code StringValue} is modified in
 *        place when a replacement applies
 */
public void visit(LikeExpression expr) {
    Expression leftExpression = expr.getLeftExpression();
    Expression rightExpression = expr.getRightExpression();
    // instanceof evaluates to false for null, so these guards also cover missing operands.
    if (!(leftExpression instanceof Column) || !(rightExpression instanceof StringValue)) {
        return;
    }
    if (CollectionUtils.isEmpty(filedNameToValueMap)) {
        return;
    }
    String columnName = ((Column) leftExpression).getColumnName();
    if (StringUtils.isEmpty(columnName)) {
        return;
    }
    Map<String, String> valueMap = filedNameToValueMap.get(columnName);
    if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
        return;
    }
    StringValue rightStringValue = (StringValue) rightExpression;
    String value = rightStringValue.getValue();
    if (value == null) {
        // Nothing to split; avoids a NullPointerException on split below.
        return;
    }
    // Split on the wildcard and convert each fragment separately. The -1 limit keeps leading
    // and trailing empty fragments, so a pattern such as "%abc%" can be rebuilt exactly.
    String[] parts = value.split("%", -1);
    boolean changed = false;
    for (int i = 0; i < parts.length; i++) {
        if (parts[i].isEmpty()) {
            continue;
        }
        String replaceValue = getReplaceValue(valueMap, parts[i]);
        if (StringUtils.isNotEmpty(replaceValue) && !parts[i].equals(replaceValue)) {
            parts[i] = replaceValue;
            changed = true;
        }
    }
    // Rebuild the literal only when a fragment actually changed.
    if (changed) {
        rightStringValue.setValue(String.join("%", parts));
    }
}
public void visit(InExpression inExpression) { public void visit(InExpression inExpression) {
if (!(inExpression.getLeftExpression() instanceof Column)) { if (!(inExpression.getLeftExpression() instanceof Column)) {
return; return;

View File

@@ -12,12 +12,17 @@ import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
@Service @Service
@Slf4j @Slf4j
@@ -72,18 +77,39 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
} }
} }
protected void executeTasks(List<Callable<Void>> tasks) { protected Set<T> executeTasks(List<Supplier<List<T>>> tasks) {
try {
executor.invokeAll(tasks); Function<Supplier<List<T>>, Supplier<List<T>>> decorator = taskDecorator();
for (Callable<Void> future : tasks) { List<CompletableFuture<List<T>>> futures;
future.call(); if (decorator == null) {
futures = tasks.stream().map(t -> CompletableFuture.supplyAsync(t, executor)).toList();
} else {
futures = tasks.stream()
.map(t -> CompletableFuture.supplyAsync(decorator.apply(t), executor)).toList();
} }
} catch (Exception e) {
CompletableFuture<List<T>> listCompletableFuture =
CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]))
.thenApply(v -> futures.stream()
.flatMap(listFuture -> listFuture.join().stream())
.collect(Collectors.toList()));
try {
List<T> ts = listCompletableFuture.get();
Set<T> results = new HashSet<>();
selectResultInOneRound(results, ts);
return results;
} catch (InterruptedException e) {
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
throw new RuntimeException("Task execution interrupted", e); throw new RuntimeException("Task execution interrupted", e);
} catch (ExecutionException e) {
throw new RuntimeException(e);
} }
} }
/**
 * Hook that lets a subclass wrap every task before it is handed to the executor.
 *
 * <p>{@code executeTasks} calls this method once per invocation. When the result is
 * {@code null} the tasks are submitted unchanged; otherwise each task is first passed through
 * the returned function and the resulting supplier is submitted instead.
 *
 * @return a function that maps a task to its decorated form, or {@code null} for no
 *         decoration; this base implementation always returns {@code null}
 */
public Function<Supplier<List<T>>, Supplier<List<T>>> taskDecorator() {
return null;
}
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
if (MapModeEnum.STRICT.equals(mapModeEnum)) { if (MapModeEnum.STRICT.equals(mapModeEnum)) {
return 1.0d; return 1.0d;

View File

@@ -17,6 +17,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@@ -76,6 +78,22 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
return allElements; return allElements;
} }
/**
 * Builds a decorator that exposes the schema elements currently held in {@code allElements}
 * to each task for the duration of that task.
 *
 * <p>The value is read once, on the thread that calls this method, and captured by the
 * returned function. Each decorated task stores the captured list into {@code allElements}
 * before delegating to the original task, and calls {@code allElements.remove()} afterwards,
 * whether the task returned normally or threw.
 *
 * <p>NOTE(review): {@code allElements} is presumably a {@code ThreadLocal} (it is used via
 * get/set/remove) - confirm. If a decorated task can ever run on the thread that called this
 * method, the {@code remove()} in {@code finally} would also clear that thread's value.
 *
 * @return a task decorator, or {@code null} when {@code allElements} currently holds no
 *         elements
 */
@Override
public Function<Supplier<List<DatabaseMapResult>>, Supplier<List<DatabaseMapResult>>> taskDecorator() {
    final List<SchemaElement> capturedElements = allElements.get();
    if (CollectionUtils.isEmpty(capturedElements)) {
        return null;
    }
    return task -> () -> {
        try {
            // Make the caller's schema elements visible on whichever thread runs the task.
            allElements.set(capturedElements);
            return task.get();
        } finally {
            // Always clear the value so it does not outlive the task on this thread.
            allElements.remove();
        }
    };
}
private Double getThreshold(ChatQueryContext chatQueryContext) { private Double getThreshold(ChatQueryContext chatQueryContext) {
Double threshold = Double threshold =
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD)); Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));

View File

@@ -24,8 +24,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.*; import java.util.*;
import java.util.concurrent.Callable; import java.util.function.Supplier;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*; import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
@@ -141,7 +140,6 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
*/ */
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext, public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) { Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
int embeddingMapperBatch = Integer int embeddingMapperBatch = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
@@ -154,12 +152,11 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
Lists.partition(queryTextsList, embeddingMapperBatch); Lists.partition(queryTextsList, embeddingMapperBatch);
// Create and execute tasks for each batch // Create and execute tasks for each batch
List<Callable<Void>> tasks = new ArrayList<>(); List<Supplier<List<EmbeddingResult>>> tasks = new ArrayList<>();
for (List<String> queryTextsSub : queryTextsSubList) { for (List<String> queryTextsSub : queryTextsSubList) {
tasks.add( tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, useLlm));
createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm));
} }
executeTasks(tasks); Set<EmbeddingResult> results = executeTasks(tasks);
// Apply LLM filtering if enabled // Apply LLM filtering if enabled
if (useLlm) { if (useLlm) {
@@ -196,20 +193,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* @param chatQueryContext The context of the chat query * @param chatQueryContext The context of the chat query
* @param detectDataSetIds Target dataset IDs * @param detectDataSetIds Target dataset IDs
* @param queryTextsSub Sub-list of query texts to process * @param queryTextsSub Sub-list of query texts to process
* @param results Shared result set for collecting results
* @param useLlm Whether to use LLM * @param useLlm Whether to use LLM
* @return Callable task * @return Supplier task
*/ */
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds, private Supplier<List<EmbeddingResult>> createTask(ChatQueryContext chatQueryContext,
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) { Set<Long> detectDataSetIds, List<String> queryTextsSub, boolean useLlm) {
return () -> { return () -> detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext,
List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds, useLlm);
queryTextsSub, chatQueryContext, useLlm);
synchronized (results) {
selectResultInOneRound(results, oneRoundResults);
}
return null;
};
} }
/** /**

View File

@@ -11,8 +11,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Callable; import java.util.function.Supplier;
import java.util.concurrent.ConcurrentHashMap;
@Service @Service
@Slf4j @Slf4j
@@ -26,8 +25,7 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
String text = chatQueryContext.getRequest().getQueryText(); String text = chatQueryContext.getRequest().getQueryText();
Set<T> results = ConcurrentHashMap.newKeySet(); List<Supplier<List<T>>> tasks = new ArrayList<>();
List<Callable<Void>> tasks = new ArrayList<>();
for (int startIndex = 0; startIndex <= text.length() - 1;) { for (int startIndex = 0; startIndex <= text.length() - 1;) {
for (int index = startIndex; index <= text.length();) { for (int index = startIndex; index <= text.length();) {
@@ -35,27 +33,20 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
index = mapperHelper.getStepIndex(regOffsetToLength, index); index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) { if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim(); String detectSegment = text.substring(startIndex, index).trim();
Callable<Void> task = createTask(chatQueryContext, detectDataSetIds, Supplier<List<T>> task =
detectSegment, offset, results); createTask(chatQueryContext, detectDataSetIds, detectSegment, offset);
tasks.add(task); tasks.add(task);
} }
} }
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
} }
executeTasks(tasks); Set<T> results = executeTasks(tasks);
return new ArrayList<>(results); return new ArrayList<>(results);
} }
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds, private Supplier<List<T>> createTask(ChatQueryContext chatQueryContext,
String detectSegment, int offset, Set<T> results) { Set<Long> detectDataSetIds, String detectSegment, int offset) {
return () -> { return () -> detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
List<T> oneRoundResults =
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
synchronized (results) {
selectResultInOneRound(results, oneRoundResults);
}
return null;
};
} }
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext, public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,