mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-20 21:54:19 +08:00
Compare commits
5 Commits
master
...
58e41cd4bc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58e41cd4bc | ||
|
|
978ae53fb3 | ||
|
|
e04bc3cce8 | ||
|
|
aaf2d46a56 | ||
|
|
c8abea9c1a |
@@ -1,10 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.headless.api.pojo;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.*;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.ToString;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@@ -21,6 +17,7 @@ public class SchemaElementMatch implements Serializable {
|
|||||||
private String word;
|
private String word;
|
||||||
private Long frequency;
|
private Long frequency;
|
||||||
private boolean isInherited;
|
private boolean isInherited;
|
||||||
|
private boolean llmMatched;
|
||||||
|
|
||||||
public boolean isFullMatched() {
|
public boolean isFullMatched() {
|
||||||
return 1.0 == similarity;
|
return 1.0 == similarity;
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ public class EmbeddingResult extends MapResult {
|
|||||||
|
|
||||||
private String id;
|
private String id;
|
||||||
private Map<String, String> metadata;
|
private Map<String, String> metadata;
|
||||||
|
private boolean llmMatched;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||||
@@ -11,6 +14,7 @@ import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
|||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -23,10 +27,16 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean accept(ChatQueryContext chatQueryContext) {
|
public boolean accept(ChatQueryContext chatQueryContext) {
|
||||||
return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
|
boolean b0 = MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
|
||||||
|
boolean b1 = chatQueryContext.getRequest().getText2SQLType() == Text2SQLType.LLM_OR_RULE;
|
||||||
|
return b0 || b1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void doMap(ChatQueryContext chatQueryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
|
|
||||||
|
// TODO: 如果是在LOOSE执行过了,那么在LLM_OR_RULE阶段可以不用执行,所以这里缺乏一个状态来传递,暂时先忽略这个浪费行为吧
|
||||||
|
SchemaMapInfo mappedInfo = chatQueryContext.getMapInfo();
|
||||||
|
|
||||||
// 1. Query from embedding by queryText
|
// 1. Query from embedding by queryText
|
||||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
|
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
|
||||||
@@ -53,15 +63,26 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Build SchemaElementMatch object
|
// Build SchemaElementMatch object
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||||
.word(matchResult.getName()).similarity(matchResult.getSimilarity())
|
.word(matchResult.getName()).similarity(matchResult.getSimilarity())
|
||||||
.detectWord(matchResult.getDetectWord()).build();
|
.detectWord(matchResult.getDetectWord()).build();
|
||||||
|
schemaElementMatch.setLlmMatched(matchResult.isLlmMatched());
|
||||||
|
|
||||||
// 3. Add SchemaElementMatch to mapInfo
|
// 3. Add SchemaElementMatch to mapInfo
|
||||||
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
|
if (CollectionUtils.isEmpty(matchResults)) {
|
||||||
|
log.info("embedding mapper no match");
|
||||||
|
} else {
|
||||||
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
|
log.info("embedding match name=[{}],detectWord=[{}],similarity=[{}],metadata=[{}]",
|
||||||
|
matchResult.getName(), matchResult.getDetectWord(),
|
||||||
|
matchResult.getSimilarity(), JsonUtil.toString(matchResult.getMetadata()));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||||
|
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
@@ -14,18 +22,12 @@ import org.springframework.beans.BeanUtils;
|
|||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.concurrent.Callable;
|
import java.util.concurrent.Callable;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
|
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
|
||||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
|
|
||||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings
|
* EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings
|
||||||
@@ -35,37 +37,165 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> {
|
public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
protected MapperConfig mapperConfig;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private MetaEmbeddingService metaEmbeddingService;
|
private MetaEmbeddingService metaEmbeddingService;
|
||||||
|
|
||||||
|
private static final String LLM_FILTER_PROMPT =
|
||||||
|
"""
|
||||||
|
\
|
||||||
|
#Role: You are a professional data analyst specializing in metrics and dimensions.
|
||||||
|
#Task: Given a user query and a list of retrieved metrics/dimensions through vector recall,
|
||||||
|
please analyze which metrics/dimensions the user is most likely interested in.
|
||||||
|
#Rules:
|
||||||
|
1. Based on user query and retrieved info, accurately determine metrics/dimensions user truly cares about.
|
||||||
|
2. Do not return all retrieved info, only select those highly relevant to user query.
|
||||||
|
3. Maintain high quality output, exclude metrics/dimensions irrelevant to user intent.
|
||||||
|
4. Output must be in JSON array format, only include IDs from retrieved info, e.g.: ['id1', 'id2']
|
||||||
|
5. Return JSON content directly without markdown formatting
|
||||||
|
#Input Example:
|
||||||
|
#User Query: {{userText}}
|
||||||
|
#Retrieved Metrics/Dimensions: {{retrievedInfo}}
|
||||||
|
#Output:""";
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||||
|
Set<Long> detectDataSetIds) {
|
||||||
|
if (chatQueryContext == null || CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||||
|
log.warn("Invalid input parameters: context={}, dataSetIds={}", chatQueryContext,
|
||||||
|
detectDataSetIds);
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Base detection
|
||||||
|
List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds);
|
||||||
|
|
||||||
|
boolean useLLM = Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
|
||||||
|
|
||||||
|
// 2. LLM enhanced detection
|
||||||
|
if (useLLM) {
|
||||||
|
List<EmbeddingResult> llmResults = detectWithLLM(chatQueryContext, detectDataSetIds);
|
||||||
|
if (!CollectionUtils.isEmpty(llmResults)) {
|
||||||
|
baseResults.addAll(llmResults);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Deduplicate results
|
||||||
|
return baseResults.stream().distinct().collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform enhanced detection using LLM
|
||||||
|
*/
|
||||||
|
private List<EmbeddingResult> detectWithLLM(ChatQueryContext chatQueryContext,
|
||||||
|
Set<Long> detectDataSetIds) {
|
||||||
|
try {
|
||||||
|
String queryText = chatQueryContext.getRequest().getQueryText();
|
||||||
|
if (StringUtils.isBlank(queryText)) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get segmentation results
|
||||||
|
Set<String> detectSegments = extractValidSegments(queryText);
|
||||||
|
if (CollectionUtils.isEmpty(detectSegments)) {
|
||||||
|
log.info("No valid segments found for text: {}", queryText);
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, true);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Error in LLM detection for context: {}", chatQueryContext, e);
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract valid word segments by filtering out unwanted word natures
|
||||||
|
*/
|
||||||
|
private Set<String> extractValidSegments(String text) {
|
||||||
|
List<String> natureList = Arrays.asList(StringUtils.split(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE ), ","));
|
||||||
|
return HanlpHelper.getSegment().seg(text).stream()
|
||||||
|
.filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature)))
|
||||||
|
.map(Term::getWord).collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||||
|
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Process detection in batches with LLM option
|
||||||
|
*
|
||||||
|
* @param chatQueryContext The context of the chat query
|
||||||
|
* @param detectDataSetIds Target dataset IDs for detection
|
||||||
|
* @param detectSegments Segments to be detected
|
||||||
|
* @param useLlm Whether to use LLM for filtering results
|
||||||
|
* @return List of embedding results
|
||||||
|
*/
|
||||||
|
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||||
|
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
|
||||||
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
|
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
|
||||||
int embeddingMapperBatch = Integer
|
int embeddingMapperBatch = Integer
|
||||||
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||||
|
|
||||||
List<String> queryTextsList =
|
// Process and filter query texts
|
||||||
detectSegments.stream().map(detectSegment -> detectSegment.trim())
|
List<String> queryTextsList = detectSegments.stream().map(String::trim)
|
||||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
|
.filter(StringUtils::isNotBlank).collect(Collectors.toList());
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
|
// Partition queries into sub-lists for batch processing
|
||||||
List<List<String>> queryTextsSubList =
|
List<List<String>> queryTextsSubList =
|
||||||
Lists.partition(queryTextsList, embeddingMapperBatch);
|
Lists.partition(queryTextsList, embeddingMapperBatch);
|
||||||
|
|
||||||
|
// Create and execute tasks for each batch
|
||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results));
|
tasks.add(
|
||||||
|
createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm));
|
||||||
}
|
}
|
||||||
executeTasks(tasks);
|
executeTasks(tasks);
|
||||||
|
|
||||||
|
// Apply LLM filtering if enabled
|
||||||
|
if (useLlm) {
|
||||||
|
Map<String, Object> variable = new HashMap<>();
|
||||||
|
variable.put("userText", chatQueryContext.getRequest().getQueryText());
|
||||||
|
variable.put("retrievedInfo", JSONObject.toJSONString(results));
|
||||||
|
|
||||||
|
Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable);
|
||||||
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel();
|
||||||
|
String response = chatLanguageModel.generate(prompt.toUserMessage().singleText());
|
||||||
|
|
||||||
|
if (StringUtils.isBlank(response)) {
|
||||||
|
results.clear();
|
||||||
|
} else {
|
||||||
|
List<String> retrievedIds = JSONObject.parseArray(response, String.class);
|
||||||
|
results = results.stream().filter(t -> retrievedIds.contains(t.getId()))
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
results.forEach(r -> r.setLlmMatched(true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a task for batch processing
|
||||||
|
*
|
||||||
|
* @param chatQueryContext The context of the chat query
|
||||||
|
* @param detectDataSetIds Target dataset IDs
|
||||||
|
* @param queryTextsSub Sub-list of query texts to process
|
||||||
|
* @param results Shared result set for collecting results
|
||||||
|
* @param useLlm Whether to use LLM
|
||||||
|
* @return Callable task
|
||||||
|
*/
|
||||||
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
||||||
List<String> queryTextsSub, Set<EmbeddingResult> results) {
|
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
|
||||||
return () -> {
|
return () -> {
|
||||||
List<EmbeddingResult> oneRoundResults =
|
List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
|
||||||
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
|
queryTextsSub, chatQueryContext, useLlm);
|
||||||
synchronized (results) {
|
synchronized (results) {
|
||||||
selectResultInOneRound(results, oneRoundResults);
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
}
|
}
|
||||||
@@ -73,57 +203,73 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Process a sub-list of query texts
|
||||||
|
*
|
||||||
|
* @param detectDataSetIds Target dataset IDs
|
||||||
|
* @param queryTextsSub Sub-list of query texts
|
||||||
|
* @param chatQueryContext Chat query context
|
||||||
|
* @param useLlm Whether to use LLM
|
||||||
|
* @return List of embedding results for this batch
|
||||||
|
*/
|
||||||
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
|
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
|
||||||
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
|
List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
|
||||||
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
|
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
|
||||||
|
|
||||||
|
// Get configuration parameters
|
||||||
double threshold =
|
double threshold =
|
||||||
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
Double.parseDouble(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||||
|
|
||||||
// step1. build query params
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
|
||||||
|
|
||||||
// step2. retrieveQuery by detectSegment
|
|
||||||
int embeddingNumber =
|
int embeddingNumber =
|
||||||
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
||||||
|
int embeddingRoundNumber =
|
||||||
|
Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||||
|
|
||||||
|
// Build and execute query
|
||||||
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
return new ArrayList<>();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
// step3. build EmbeddingResults
|
|
||||||
List<EmbeddingResult> collect = retrieveQueryResults.stream().map(retrieveQueryResult -> {
|
// Process results
|
||||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
List<EmbeddingResult> collect = retrieveQueryResults.stream().peek(result -> {
|
||||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
|
||||||
retrievals.removeIf(retrieval -> {
|
result.getRetrieval()
|
||||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
.removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
|
||||||
return retrieval.getSimilarity() < threshold;
|
&& retrieval.getSimilarity() < threshold);
|
||||||
}
|
|
||||||
return false;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
return retrieveQueryResult;
|
}).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
|
||||||
}).filter(retrieveQueryResult -> CollectionUtils
|
.flatMap(result -> result.getRetrieval().stream()
|
||||||
.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
.map(retrieval -> convertToEmbeddingResult(result, retrieval)))
|
||||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
|
||||||
.map(retrieval -> {
|
|
||||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
|
||||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
|
||||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
|
||||||
embeddingResult.setName(retrieval.getQuery());
|
|
||||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet()
|
|
||||||
.stream().collect(Collectors.toMap(Map.Entry::getKey,
|
|
||||||
entry -> entry.getValue().toString()));
|
|
||||||
embeddingResult.setMetadata(convertedMap);
|
|
||||||
return embeddingResult;
|
|
||||||
}))
|
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
// step4. select mapResul in one round
|
// Sort and limit results
|
||||||
int embeddingRoundNumber =
|
return collect.stream()
|
||||||
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
.sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity).reversed())
|
||||||
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
.limit(embeddingRoundNumber * queryTextsSub.size()).collect(Collectors.toList());
|
||||||
return collect.stream().sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity))
|
}
|
||||||
.limit(roundNumber).collect(Collectors.toList());
|
|
||||||
|
/**
|
||||||
|
* Convert RetrieveQueryResult and Retrieval to EmbeddingResult
|
||||||
|
*
|
||||||
|
* @param queryResult The query result containing retrieval information
|
||||||
|
* @param retrieval The retrieval data to be converted
|
||||||
|
* @return Converted EmbeddingResult
|
||||||
|
*/
|
||||||
|
private EmbeddingResult convertToEmbeddingResult(RetrieveQueryResult queryResult,
|
||||||
|
Retrieval retrieval) {
|
||||||
|
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||||
|
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||||
|
embeddingResult.setDetectWord(queryResult.getQuery());
|
||||||
|
embeddingResult.setName(retrieval.getQuery());
|
||||||
|
|
||||||
|
// Convert metadata to string values
|
||||||
|
Map<String, String> metadata = retrieval.getMetadata().entrySet().stream().collect(
|
||||||
|
Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
|
||||||
|
embeddingResult.setMetadata(metadata);
|
||||||
|
|
||||||
|
return embeddingResult;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -66,7 +61,7 @@ public class MapFilter {
|
|||||||
List<SchemaElementMatch> value = entry.getValue();
|
List<SchemaElementMatch> value = entry.getValue();
|
||||||
if (!CollectionUtils.isEmpty(value)) {
|
if (!CollectionUtils.isEmpty(value)) {
|
||||||
value.removeIf(schemaElementMatch -> StringUtils
|
value.removeIf(schemaElementMatch -> StringUtils
|
||||||
.length(schemaElementMatch.getDetectWord()) <= 1);
|
.length(schemaElementMatch.getDetectWord()) <= 1 && !schemaElementMatch.isLlmMatched());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -85,7 +80,7 @@ public class MapFilter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
|
public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
|
||||||
Predicate<SchemaElement> needRemovePredicate) {
|
Predicate<SchemaElement> needRemovePredicate) {
|
||||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||||
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||||
|
|||||||
@@ -57,4 +57,12 @@ public class MapperConfig extends ParameterConfig {
|
|||||||
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
||||||
new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值",
|
new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值",
|
||||||
"向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
|
"向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MAPPER_USE_LLM =
|
||||||
|
new Parameter("s2.mapper.embedding.use-llm-enhance", "false", "使用LLM对召回的向量进行二次判断开关",
|
||||||
|
"embedding的结果再通过一次LLM来筛选,这时候忽略各个向量阀值", "bool", "Mapper相关配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE =
|
||||||
|
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']", "使用LLM召回二次处理时对问题分词词性的控制",
|
||||||
|
"分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
spring:
|
spring:
|
||||||
datasource:
|
datasource:
|
||||||
driver-class-name: org.postgresql.Driver
|
driver-class-name: org.postgresql.Driver
|
||||||
url: jdbc:postgresql://${S2_DB_HOST:localhost}:${S2_DB_PORT:5432}/${S2_DB_DATABASE:postgres}?stringtype=unspecified
|
url: jdbc:postgresql://localhost:5432/postgres?stringtype=unspecified
|
||||||
username: ${S2_DB_USER:postgres}
|
username: postgres
|
||||||
password: ${S2_DB_PASSWORD:postgres}
|
password: postgres
|
||||||
sql:
|
sql:
|
||||||
init:
|
init:
|
||||||
mode: always
|
mode: always
|
||||||
username: ${S2_DB_USER:postgres}
|
username: postgres
|
||||||
password: ${S2_DB_PASSWORD:postgres}
|
password: postgres
|
||||||
schema-locations: classpath:db/schema-postgres.sql,classpath:db/schema-postgres-demo.sql
|
schema-locations: classpath:db/schema-postgres.sql,classpath:db/schema-postgres-demo.sql
|
||||||
data-locations: classpath:db/data-postgres.sql,classpath:db/data-postgres-demo.sql
|
data-locations: classpath:db/data-postgres.sql,classpath:db/data-postgres-demo.sql
|
||||||
|
|
||||||
@@ -17,9 +17,9 @@ s2:
|
|||||||
store:
|
store:
|
||||||
provider: PGVECTOR
|
provider: PGVECTOR
|
||||||
base:
|
base:
|
||||||
url: ${S2_DB_HOST:127.0.0.1}
|
url: 127.0.0.1
|
||||||
port: ${S2_DB_PORT:5432}
|
port: 5432
|
||||||
databaseName: ${S2_DB_DATABASE:postgres}
|
databaseName: postgres
|
||||||
user: ${S2_DB_USER:postgres}
|
user: postgres
|
||||||
password: ${S2_DB_PASSWORD:postgres}
|
password: postgres
|
||||||
dimension: 512
|
dimension: 512
|
||||||
@@ -41,3 +41,5 @@ s2:
|
|||||||
threshold: 0.5
|
threshold: 0.5
|
||||||
min:
|
min:
|
||||||
threshold: 0.3
|
threshold: 0.3
|
||||||
|
embedding:
|
||||||
|
use-llm-enhance: true
|
||||||
|
|||||||
@@ -209,5 +209,6 @@
|
|||||||
},
|
},
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=16"
|
"node": ">=16"
|
||||||
}
|
},
|
||||||
|
"packageManager": "pnpm@9.12.3+sha512.cce0f9de9c5a7c95bef944169cc5dfe8741abfb145078c0d508b868056848a87c81e626246cb60967cbd7fd29a6c062ef73ff840d96b3c86c40ac92cf4a813ee"
|
||||||
}
|
}
|
||||||
@@ -9,7 +9,7 @@ import {
|
|||||||
RangeValue,
|
RangeValue,
|
||||||
SimilarQuestionType,
|
SimilarQuestionType,
|
||||||
} from '../../common/type';
|
} from '../../common/type';
|
||||||
import { createContext, useEffect, useRef, useState } from 'react';
|
import { createContext, useEffect, useState } from 'react';
|
||||||
import { chatExecute, chatParse, queryData, deleteQuery, switchEntity } from '../../service';
|
import { chatExecute, chatParse, queryData, deleteQuery, switchEntity } from '../../service';
|
||||||
import { PARSE_ERROR_TIP, PREFIX_CLS, SEARCH_EXCEPTION_TIP } from '../../common/constants';
|
import { PARSE_ERROR_TIP, PREFIX_CLS, SEARCH_EXCEPTION_TIP } from '../../common/constants';
|
||||||
import { message, Spin } from 'antd';
|
import { message, Spin } from 'antd';
|
||||||
@@ -490,9 +490,7 @@ const ChatItem: React.FC<Props> = ({
|
|||||||
onSwitchEntity={onSwitchEntity}
|
onSwitchEntity={onSwitchEntity}
|
||||||
onFiltersChange={onFiltersChange}
|
onFiltersChange={onFiltersChange}
|
||||||
onDateInfoChange={onDateInfoChange}
|
onDateInfoChange={onDateInfoChange}
|
||||||
onRefresh={() => {
|
onRefresh={onRefresh}
|
||||||
onRefresh();
|
|
||||||
}}
|
|
||||||
handlePresetClick={handlePresetClick}
|
handlePresetClick={handlePresetClick}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -40,6 +40,12 @@ const BarChart: React.FC<Props> = ({
|
|||||||
}) => {
|
}) => {
|
||||||
const chartRef = useRef<any>();
|
const chartRef = useRef<any>();
|
||||||
const instanceRef = useRef<ECharts>();
|
const instanceRef = useRef<ECharts>();
|
||||||
|
const { downloadChartAsImage } = useExportByEcharts({
|
||||||
|
instanceRef,
|
||||||
|
question,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { register } = useContext(ChartItemContext);
|
||||||
|
|
||||||
const { queryColumns, queryResults, entityInfo } = data;
|
const { queryColumns, queryResults, entityInfo } = data;
|
||||||
|
|
||||||
@@ -189,13 +195,6 @@ const BarChart: React.FC<Props> = ({
|
|||||||
|
|
||||||
const prefixCls = `${PREFIX_CLS}-bar`;
|
const prefixCls = `${PREFIX_CLS}-bar`;
|
||||||
|
|
||||||
const { downloadChartAsImage } = useExportByEcharts({
|
|
||||||
instanceRef,
|
|
||||||
question,
|
|
||||||
});
|
|
||||||
|
|
||||||
const { register } = useContext(ChartItemContext);
|
|
||||||
|
|
||||||
register('downloadChartAsImage', downloadChartAsImage);
|
register('downloadChartAsImage', downloadChartAsImage);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -93,7 +93,9 @@ export const getFormattedValue = (value: number | string, remainZero?: boolean)
|
|||||||
|
|
||||||
export const formatNumberWithCN = (num: number) => {
|
export const formatNumberWithCN = (num: number) => {
|
||||||
if (isNaN(num)) return '-';
|
if (isNaN(num)) return '-';
|
||||||
if (num >= 10000) {
|
if (num >= 100000000) {
|
||||||
|
return (num / 100000000).toFixed(1) + '亿';
|
||||||
|
} else if (num >= 10000) {
|
||||||
return (num / 10000).toFixed(1) + '万';
|
return (num / 10000).toFixed(1) + '万';
|
||||||
} else {
|
} else {
|
||||||
return formatByDecimalPlaces(num, 2);
|
return formatByDecimalPlaces(num, 2);
|
||||||
|
|||||||
@@ -4,5 +4,9 @@ export default {
|
|||||||
target: 'http://127.0.0.1:9080',
|
target: 'http://127.0.0.1:9080',
|
||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
},
|
},
|
||||||
|
'/aibi/api/': {
|
||||||
|
target: 'http://127.0.0.1:9080',
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import { ISemantic } from '../../data';
|
|||||||
import { ColumnsConfig } from '../../components/TableColumnRender';
|
import { ColumnsConfig } from '../../components/TableColumnRender';
|
||||||
import ViewSearchFormModal from './ViewSearchFormModal';
|
import ViewSearchFormModal from './ViewSearchFormModal';
|
||||||
import { toDatasetEditPage } from '@/pages/SemanticModel/utils';
|
import { toDatasetEditPage } from '@/pages/SemanticModel/utils';
|
||||||
|
import UploadFile from './UploadFile';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
// dataSetList: ISemantic.IDatasetItem[];
|
// dataSetList: ISemantic.IDatasetItem[];
|
||||||
@@ -92,9 +93,6 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
|
|||||||
<a
|
<a
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
toDatasetEditPage(record.domainId, record.id, 'relation');
|
toDatasetEditPage(record.domainId, record.id, 'relation');
|
||||||
// setEditFormStep(1);
|
|
||||||
// setViewItem(record);
|
|
||||||
// setCreateDataSourceModalOpen(true);
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{name}
|
{name}
|
||||||
@@ -146,9 +144,6 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
|
|||||||
key="metricEditBtn"
|
key="metricEditBtn"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
toDatasetEditPage(record.domainId, record.id);
|
toDatasetEditPage(record.domainId, record.id);
|
||||||
// setEditFormStep(0);
|
|
||||||
// setViewItem(record);
|
|
||||||
// setCreateDataSourceModalOpen(true);
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
编辑
|
编辑
|
||||||
@@ -189,6 +184,12 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
|
|||||||
启用
|
启用
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
|
<UploadFile
|
||||||
|
key="uploadFile"
|
||||||
|
buttonType="link"
|
||||||
|
domainId={record.domainId}
|
||||||
|
datasetId={record.id}
|
||||||
|
/>
|
||||||
<Popconfirm
|
<Popconfirm
|
||||||
title="确认删除?"
|
title="确认删除?"
|
||||||
okText="是"
|
okText="是"
|
||||||
@@ -229,6 +230,13 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
|
|||||||
disabledEdit
|
disabledEdit
|
||||||
? [<></>]
|
? [<></>]
|
||||||
: [
|
: [
|
||||||
|
<UploadFile
|
||||||
|
key="uploadFile"
|
||||||
|
domainId={selectDomainId}
|
||||||
|
onFileUploaded={() => {
|
||||||
|
queryDataSetList();
|
||||||
|
}}
|
||||||
|
/>,
|
||||||
<Button
|
<Button
|
||||||
key="create"
|
key="create"
|
||||||
type="primary"
|
type="primary"
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
import { getToken } from '@/utils/utils';
|
||||||
|
import { UploadOutlined } from '@ant-design/icons';
|
||||||
|
import type { UploadProps } from 'antd';
|
||||||
|
import { Button, message, Upload } from 'antd';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
buttonType?: string;
|
||||||
|
domainId?: number;
|
||||||
|
datasetId?: string;
|
||||||
|
onFileUploaded?: () => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
const UploadFile = ({ buttonType, domainId, datasetId, onFileUploaded }: Props) => {
|
||||||
|
const props: UploadProps = {
|
||||||
|
name: 'multipartFile',
|
||||||
|
action: `/aibi/api/data/file/uploadFileNew?type=DATASET&domainId=${domainId}${
|
||||||
|
datasetId ? `&dataSetId=${datasetId}` : ''
|
||||||
|
}`,
|
||||||
|
showUploadList: false,
|
||||||
|
onChange(info) {
|
||||||
|
if (info.file.status !== 'uploading') {
|
||||||
|
console.log(info.file, info.fileList);
|
||||||
|
}
|
||||||
|
if (info.file.status === 'done') {
|
||||||
|
message.success('导入成功');
|
||||||
|
onFileUploaded?.();
|
||||||
|
} else if (info.file.status === 'error') {
|
||||||
|
message.error('导入失败');
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Upload {...props}>
|
||||||
|
{buttonType === 'link' ? (
|
||||||
|
<a>导入文件</a>
|
||||||
|
) : (
|
||||||
|
<Button icon={<UploadOutlined />}>导入文件</Button>
|
||||||
|
)}
|
||||||
|
</Upload>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default UploadFile;
|
||||||
28637
webapp/pnpm-lock.yaml
generated
28637
webapp/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user