mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Compare commits
2 Commits
5df0b87da9
...
80aaabe58b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80aaabe58b | ||
|
|
5a4fd2b888 |
@@ -46,6 +46,62 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
replaceComparisonExpression(expr);
|
replaceComparisonExpression(expr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void visit(LikeExpression expr) {
|
||||||
|
Expression leftExpression = expr.getLeftExpression();
|
||||||
|
Expression rightExpression = expr.getRightExpression();
|
||||||
|
|
||||||
|
if (!(leftExpression instanceof Column)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Column column = (Column) leftExpression;
|
||||||
|
String columnName = column.getColumnName();
|
||||||
|
if (StringUtils.isEmpty(columnName)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
|
||||||
|
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (rightExpression instanceof StringValue) {
|
||||||
|
StringValue rightStringValue = (StringValue) rightExpression;
|
||||||
|
String value = rightStringValue.getValue();
|
||||||
|
|
||||||
|
// 使用split处理方式,按通配符分割字符串,对每个片段进行转换
|
||||||
|
String[] parts = value.split("%", -1);
|
||||||
|
boolean changed = false;
|
||||||
|
|
||||||
|
// 处理每个部分
|
||||||
|
for (int i = 0; i < parts.length; i++) {
|
||||||
|
if (!parts[i].isEmpty()) {
|
||||||
|
String replaceValue = getReplaceValue(valueMap, parts[i]);
|
||||||
|
if (StringUtils.isNotEmpty(replaceValue) && !parts[i].equals(replaceValue)) {
|
||||||
|
parts[i] = replaceValue;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有任何部分发生变化,则重新构建字符串
|
||||||
|
if (changed) {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (int i = 0; i < parts.length; i++) {
|
||||||
|
sb.append(parts[i]);
|
||||||
|
// 除了最后一个部分,其他部分后面都需要加上"%"
|
||||||
|
if (i < parts.length - 1) {
|
||||||
|
sb.append("%");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rightStringValue.setValue(sb.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void visit(InExpression inExpression) {
|
public void visit(InExpression inExpression) {
|
||||||
if (!(inExpression.getLeftExpression() instanceof Column)) {
|
if (!(inExpression.getLeftExpression() instanceof Column)) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -12,12 +12,17 @@ import org.springframework.beans.factory.annotation.Qualifier;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.Callable;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.ThreadPoolExecutor;
|
import java.util.concurrent.ThreadPoolExecutor;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Supplier;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -72,18 +77,39 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void executeTasks(List<Callable<Void>> tasks) {
|
protected Set<T> executeTasks(List<Supplier<List<T>>> tasks) {
|
||||||
try {
|
|
||||||
executor.invokeAll(tasks);
|
Function<Supplier<List<T>>, Supplier<List<T>>> decorator = taskDecorator();
|
||||||
for (Callable<Void> future : tasks) {
|
List<CompletableFuture<List<T>>> futures;
|
||||||
future.call();
|
if (decorator == null) {
|
||||||
|
futures = tasks.stream().map(t -> CompletableFuture.supplyAsync(t, executor)).toList();
|
||||||
|
} else {
|
||||||
|
futures = tasks.stream()
|
||||||
|
.map(t -> CompletableFuture.supplyAsync(decorator.apply(t), executor)).toList();
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
|
||||||
|
CompletableFuture<List<T>> listCompletableFuture =
|
||||||
|
CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]))
|
||||||
|
.thenApply(v -> futures.stream()
|
||||||
|
.flatMap(listFuture -> listFuture.join().stream())
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
try {
|
||||||
|
List<T> ts = listCompletableFuture.get();
|
||||||
|
Set<T> results = new HashSet<>();
|
||||||
|
selectResultInOneRound(results, ts);
|
||||||
|
return results;
|
||||||
|
} catch (InterruptedException e) {
|
||||||
Thread.currentThread().interrupt();
|
Thread.currentThread().interrupt();
|
||||||
throw new RuntimeException("Task execution interrupted", e);
|
throw new RuntimeException("Task execution interrupted", e);
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Function<Supplier<List<T>>, Supplier<List<T>>> taskDecorator() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||||
if (MapModeEnum.STRICT.equals(mapModeEnum)) {
|
if (MapModeEnum.STRICT.equals(mapModeEnum)) {
|
||||||
return 1.0d;
|
return 1.0d;
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Map.Entry;
|
import java.util.Map.Entry;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Supplier;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -76,6 +78,22 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
|
|||||||
return allElements;
|
return allElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Function<Supplier<List<DatabaseMapResult>>, Supplier<List<DatabaseMapResult>>> taskDecorator() {
|
||||||
|
List<SchemaElement> schemaElements = allElements.get();
|
||||||
|
if (CollectionUtils.isEmpty(schemaElements)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return (t) -> (Supplier<List<DatabaseMapResult>>) () -> {
|
||||||
|
try {
|
||||||
|
allElements.set(schemaElements);
|
||||||
|
return t.get();
|
||||||
|
} finally {
|
||||||
|
allElements.remove();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private Double getThreshold(ChatQueryContext chatQueryContext) {
|
private Double getThreshold(ChatQueryContext chatQueryContext) {
|
||||||
Double threshold =
|
Double threshold =
|
||||||
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.Callable;
|
import java.util.function.Supplier;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
|
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
|
||||||
@@ -141,7 +140,6 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
*/
|
*/
|
||||||
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||||
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
|
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
|
||||||
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
|
|
||||||
int embeddingMapperBatch = Integer
|
int embeddingMapperBatch = Integer
|
||||||
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||||
|
|
||||||
@@ -154,12 +152,11 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
Lists.partition(queryTextsList, embeddingMapperBatch);
|
Lists.partition(queryTextsList, embeddingMapperBatch);
|
||||||
|
|
||||||
// Create and execute tasks for each batch
|
// Create and execute tasks for each batch
|
||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
List<Supplier<List<EmbeddingResult>>> tasks = new ArrayList<>();
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
tasks.add(
|
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, useLlm));
|
||||||
createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm));
|
|
||||||
}
|
}
|
||||||
executeTasks(tasks);
|
Set<EmbeddingResult> results = executeTasks(tasks);
|
||||||
|
|
||||||
// Apply LLM filtering if enabled
|
// Apply LLM filtering if enabled
|
||||||
if (useLlm) {
|
if (useLlm) {
|
||||||
@@ -196,20 +193,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
* @param chatQueryContext The context of the chat query
|
* @param chatQueryContext The context of the chat query
|
||||||
* @param detectDataSetIds Target dataset IDs
|
* @param detectDataSetIds Target dataset IDs
|
||||||
* @param queryTextsSub Sub-list of query texts to process
|
* @param queryTextsSub Sub-list of query texts to process
|
||||||
* @param results Shared result set for collecting results
|
|
||||||
* @param useLlm Whether to use LLM
|
* @param useLlm Whether to use LLM
|
||||||
* @return Callable task
|
* @return Supplier task
|
||||||
*/
|
*/
|
||||||
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
private Supplier<List<EmbeddingResult>> createTask(ChatQueryContext chatQueryContext,
|
||||||
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
|
Set<Long> detectDataSetIds, List<String> queryTextsSub, boolean useLlm) {
|
||||||
return () -> {
|
return () -> detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext,
|
||||||
List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
|
useLlm);
|
||||||
queryTextsSub, chatQueryContext, useLlm);
|
|
||||||
synchronized (results) {
|
|
||||||
selectResultInOneRound(results, oneRoundResults);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ import java.util.ArrayList;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.Callable;
|
import java.util.function.Supplier;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -26,8 +25,7 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
|||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectDataSetIds) {
|
||||||
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
|
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
|
||||||
String text = chatQueryContext.getRequest().getQueryText();
|
String text = chatQueryContext.getRequest().getQueryText();
|
||||||
Set<T> results = ConcurrentHashMap.newKeySet();
|
List<Supplier<List<T>>> tasks = new ArrayList<>();
|
||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
|
||||||
|
|
||||||
for (int startIndex = 0; startIndex <= text.length() - 1;) {
|
for (int startIndex = 0; startIndex <= text.length() - 1;) {
|
||||||
for (int index = startIndex; index <= text.length();) {
|
for (int index = startIndex; index <= text.length();) {
|
||||||
@@ -35,27 +33,20 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
|||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
if (index <= text.length()) {
|
if (index <= text.length()) {
|
||||||
String detectSegment = text.substring(startIndex, index).trim();
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
Callable<Void> task = createTask(chatQueryContext, detectDataSetIds,
|
Supplier<List<T>> task =
|
||||||
detectSegment, offset, results);
|
createTask(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
||||||
tasks.add(task);
|
tasks.add(task);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
}
|
}
|
||||||
executeTasks(tasks);
|
Set<T> results = executeTasks(tasks);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
private Supplier<List<T>> createTask(ChatQueryContext chatQueryContext,
|
||||||
String detectSegment, int offset, Set<T> results) {
|
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||||
return () -> {
|
return () -> detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
||||||
List<T> oneRoundResults =
|
|
||||||
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
|
||||||
synchronized (results) {
|
|
||||||
selectResultInOneRound(results, oneRoundResults);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
|
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
|
||||||
|
|||||||
Reference in New Issue
Block a user