diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java
index 42f30000f..d6117ea56 100644
--- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java
+++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java
@@ -1,10 +1,6 @@
 package com.tencent.supersonic.headless.api.pojo;
 
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import lombok.ToString;
+import lombok.*;
 
 import java.io.Serializable;
 
@@ -21,6 +17,7 @@ public class SchemaElementMatch implements Serializable {
     private String word;
     private Long frequency;
     private boolean isInherited;
+    private boolean llmMatched;
 
     public boolean isFullMatched() {
         return 1.0 == similarity;
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java
index 569892af9..85c35540b 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/EmbeddingResult.java
@@ -13,6 +13,7 @@ public class EmbeddingResult extends MapResult {
 
     private String id;
     private Map<String, String> metadata;
+    private boolean llmMatched;
 
     @Override
     public boolean equals(Object o) {
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java
index eadcb7375..ff778ee46 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java
@@ -1,9 +1,12 @@
 package com.tencent.supersonic.headless.chat.mapper;
 
+import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
 import com.tencent.supersonic.common.util.ContextUtils;
+import com.tencent.supersonic.common.util.JsonUtil;
 import com.tencent.supersonic.headless.api.pojo.SchemaElement;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
+import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
 import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
 import com.tencent.supersonic.headless.chat.ChatQueryContext;
 import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
@@ -11,6 +14,7 @@ import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
 import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
 import dev.langchain4j.store.embedding.Retrieval;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.util.CollectionUtils;
 
 import java.util.List;
 import java.util.Objects;
@@ -23,10 +27,16 @@ public class EmbeddingMapper extends BaseMapper {
 
     @Override
     public boolean accept(ChatQueryContext chatQueryContext) {
-        return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
+        return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())
+                || Text2SQLType.LLM_OR_RULE == chatQueryContext.getRequest().getText2SQLType();
     }
 
     public void doMap(ChatQueryContext chatQueryContext) {
+
+        // TODO: if the LOOSE stage has already run, this mapper need not run again in the
+        // LLM_OR_RULE stage; no state is passed between stages yet, so tolerate the extra work for now.
+        SchemaMapInfo mappedInfo = chatQueryContext.getMapInfo();
+
         // 1. Query from embedding by queryText
         EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
         List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
@@ -53,15 +63,26 @@ public class EmbeddingMapper extends BaseMapper {
                 continue;
             }
 
+            // Build SchemaElementMatch object
             SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                     .element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
                     .word(matchResult.getName()).similarity(matchResult.getSimilarity())
                     .detectWord(matchResult.getDetectWord()).build();
+            schemaElementMatch.setLlmMatched(matchResult.isLlmMatched());
 
             // 3. Add SchemaElementMatch to mapInfo
             addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
         }
+        if (CollectionUtils.isEmpty(matchResults)) {
+            log.info("embedding mapper no match");
+        } else {
+            for (EmbeddingResult matchResult : matchResults) {
+                log.info("embedding match name=[{}],detectWord=[{}],similarity=[{}],metadata=[{}]",
+                        matchResult.getName(), matchResult.getDetectWord(),
+                        matchResult.getSimilarity(), JsonUtil.toString(matchResult.getMetadata()));
+            }
+        }
     }
 }
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java
index e6e52c5d4..618ed2d1b 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java
@@ -1,9 +1,17 @@
 package com.tencent.supersonic.headless.chat.mapper;
 
+import com.alibaba.fastjson.JSONObject;
 import com.google.common.collect.Lists;
+import com.hankcs.hanlp.seg.common.Term;
+import com.tencent.supersonic.headless.api.pojo.response.S2Term;
 import com.tencent.supersonic.headless.chat.ChatQueryContext;
 import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
 import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
+import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.input.Prompt;
+import dev.langchain4j.model.input.PromptTemplate;
+import dev.langchain4j.provider.ModelProvider;
 import dev.langchain4j.store.embedding.Retrieval;
 import dev.langchain4j.store.embedding.RetrieveQuery;
 import dev.langchain4j.store.embedding.RetrieveQueryResult;
@@ -14,18 +22,12 @@ import org.springframework.beans.BeanUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
-import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
-import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
-import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
+import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
 
 /**
  * EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings
@@ -35,37 +37,165 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING
 @Slf4j
 public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> {
 
+    @Autowired
+    protected MapperConfig mapperConfig;
+
     @Autowired
     private MetaEmbeddingService metaEmbeddingService;
 
+    private static final String LLM_FILTER_PROMPT =
+            """
+                    \
+                    #Role: You are a professional data analyst specializing in metrics and dimensions.
+                    #Task: Given a user query and a list of retrieved metrics/dimensions through vector recall,
+                    please analyze which metrics/dimensions the user is most likely interested in.
+                    #Rules:
+                    1. Based on user query and retrieved info, accurately determine metrics/dimensions user truly cares about.
+                    2. Do not return all retrieved info, only select those highly relevant to user query.
+                    3. Maintain high quality output, exclude metrics/dimensions irrelevant to user intent.
+                    4. Output must be in JSON array format, only include IDs from retrieved info, e.g.: ['id1', 'id2']
+                    5. Return JSON content directly without markdown formatting
+                    #Input Example:
+                    #User Query: {{userText}}
+                    #Retrieved Metrics/Dimensions: {{retrievedInfo}}
+                    #Output:""";
+
+    @Override
+    public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
+            Set<Long> detectDataSetIds) {
+        if (chatQueryContext == null || CollectionUtils.isEmpty(detectDataSetIds)) {
+            log.warn("Invalid input parameters: context={}, dataSetIds={}", chatQueryContext,
+                    detectDataSetIds);
+            return Collections.emptyList();
+        }
+
+        // 1. Base detection
+        List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds);
+
+        boolean useLLM =
+                Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
+
+        // 2. LLM enhanced detection
+        if (useLLM) {
+            List<EmbeddingResult> llmResults = detectWithLLM(chatQueryContext, detectDataSetIds);
+            if (!CollectionUtils.isEmpty(llmResults)) {
+                baseResults.addAll(llmResults);
+            }
+        }
+
+        // 3. Deduplicate results
+        return baseResults.stream().distinct().collect(Collectors.toList());
+    }
+
+    /**
+     * Perform enhanced detection using LLM
+     */
+    private List<EmbeddingResult> detectWithLLM(ChatQueryContext chatQueryContext,
+            Set<Long> detectDataSetIds) {
+        try {
+            String queryText = chatQueryContext.getRequest().getQueryText();
+            if (StringUtils.isBlank(queryText)) {
+                return Collections.emptyList();
+            }
+
+            // Get segmentation results
+            Set<String> detectSegments = extractValidSegments(queryText);
+            if (CollectionUtils.isEmpty(detectSegments)) {
+                log.info("No valid segments found for text: {}", queryText);
+                return Collections.emptyList();
+            }
+
+            return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, true);
+        } catch (Exception e) {
+            log.error("Error in LLM detection for context: {}", chatQueryContext, e);
+            return Collections.emptyList();
+        }
+    }
+
+    /**
+     * Extract valid word segments by filtering out unwanted word natures
+     */
+    private Set<String> extractValidSegments(String text) {
+        // Strip surrounding brackets/quotes so a configured value like ['v', 'd', 'a'] yields plain natures
+        List<String> natureList = Arrays
+                .stream(StringUtils.split(
+                        mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE), ","))
+                .map(nature -> StringUtils.strip(nature, " []'\""))
+                .collect(Collectors.toList());
+        return HanlpHelper.getSegment().seg(text).stream()
+                .filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature)))
+                .map(Term::getWord).collect(Collectors.toSet());
+    }
+
     @Override
     public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
             Set<Long> detectDataSetIds, Set<String> detectSegments) {
+        return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, false);
+    }
+
+    /**
+     * Process detection in batches with LLM option
+     *
+     * @param chatQueryContext The context of the chat query
+     * @param detectDataSetIds Target dataset IDs for detection
+     * @param detectSegments Segments to be detected
+     * @param useLlm Whether to use LLM for filtering results
+     * @return List of embedding results
+     */
+    public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
+            Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
         Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
         int embeddingMapperBatch = Integer
                 .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
 
-        List<String> queryTextsList =
-                detectSegments.stream().map(detectSegment -> detectSegment.trim())
-                        .filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
-                        .collect(Collectors.toList());
+        // Process and filter query texts
+        List<String> queryTextsList = detectSegments.stream().map(String::trim)
+                .filter(StringUtils::isNotBlank).collect(Collectors.toList());
 
+        // Partition queries into sub-lists for batch processing
         List<List<String>> queryTextsSubList = Lists.partition(queryTextsList, embeddingMapperBatch);
 
+        // Create and execute tasks for each batch
         List<Callable<Void>> tasks = new ArrayList<>();
         for (List<String> queryTextsSub : queryTextsSubList) {
-            tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results));
+            tasks.add(
+                    createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm));
         }
         executeTasks(tasks);
+
+        // Apply LLM filtering if enabled
+        if (useLlm) {
+            Map<String, Object> variable = new HashMap<>();
+            variable.put("userText", chatQueryContext.getRequest().getQueryText());
+            variable.put("retrievedInfo", JSONObject.toJSONString(results));
+
+            Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable);
+            ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel();
+            String response = chatLanguageModel.generate(prompt.toUserMessage().singleText());
+
+            if (StringUtils.isBlank(response)) {
+                results.clear();
+            } else {
+                List<String> retrievedIds = JSONObject.parseArray(response, String.class);
+                results = results.stream().filter(t -> retrievedIds.contains(t.getId()))
+                        .collect(Collectors.toSet());
+                results.forEach(r -> r.setLlmMatched(true));
+            }
+        }
+
         return new ArrayList<>(results);
     }
 
+    /**
+     * Create a task for batch processing
+     *
+     * @param chatQueryContext The context of the chat query
+     * @param detectDataSetIds Target dataset IDs
+     * @param queryTextsSub Sub-list of query texts to process
+     * @param results Shared result set for collecting results
+     * @param useLlm Whether to use LLM
+     * @return Callable task
+     */
     private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
-            List<String> queryTextsSub, Set<EmbeddingResult> results) {
+            List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
         return () -> {
-            List<EmbeddingResult> oneRoundResults =
-                    detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
+            List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
+                    queryTextsSub, chatQueryContext, useLlm);
             synchronized (results) {
                 selectResultInOneRound(results, oneRoundResults);
             }
@@ -73,57 +203,73 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
         };
     }
 
+    /**
+     * Process a sub-list of query texts
+     *
+     * @param detectDataSetIds Target dataset IDs
+     * @param queryTextsSub Sub-list of query texts
+     * @param chatQueryContext Chat query context
+     * @param useLlm Whether to use LLM
+     * @return List of embedding results for this batch
+     */
     private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
-            List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
+            List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
         Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
+
+        // Get configuration parameters
         double threshold =
-                Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
-
-        // step1. build query params
-        RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
-
-        // step2. retrieveQuery by detectSegment
-        int embeddingNumber =
-                Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
+                Double.parseDouble(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
+        int embeddingNumber =
+                Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
+        int embeddingRoundNumber =
+                Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
+
+        // Build and execute query
+        RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
         List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
                 retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
         if (CollectionUtils.isEmpty(retrieveQueryResults)) {
-            return new ArrayList<>();
+            return Collections.emptyList();
        }
-        // step3. build EmbeddingResults
-        List<EmbeddingResult> collect = retrieveQueryResults.stream().map(retrieveQueryResult -> {
-            List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
-            if (CollectionUtils.isNotEmpty(retrievals)) {
-                retrievals.removeIf(retrieval -> {
-                    if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
-                        return retrieval.getSimilarity() < threshold;
-                    }
-                    return false;
-                });
+
+        // Process results
+        List<EmbeddingResult> collect = retrieveQueryResults.stream().peek(result -> {
+            if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
+                result.getRetrieval()
+                        .removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
+                                && retrieval.getSimilarity() < threshold);
             }
-            return retrieveQueryResult;
-        }).filter(retrieveQueryResult -> CollectionUtils
-                .isNotEmpty(retrieveQueryResult.getRetrieval()))
-                .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
-                        .map(retrieval -> {
-                            EmbeddingResult embeddingResult = new EmbeddingResult();
-                            BeanUtils.copyProperties(retrieval, embeddingResult);
-                            embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
-                            embeddingResult.setName(retrieval.getQuery());
-                            Map<String, String> convertedMap = retrieval.getMetadata().entrySet()
-                                    .stream().collect(Collectors.toMap(Map.Entry::getKey,
-                                            entry -> entry.getValue().toString()));
-                            embeddingResult.setMetadata(convertedMap);
-                            return embeddingResult;
-                        }))
+        }).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
+                .flatMap(result -> result.getRetrieval().stream()
+                        .map(retrieval -> convertToEmbeddingResult(result, retrieval)))
                 .collect(Collectors.toList());
 
-        // step4. select mapResul in one round
-        int embeddingRoundNumber =
-                Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
-        int roundNumber = embeddingRoundNumber * queryTextsSub.size();
-        return collect.stream().sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity))
-                .limit(roundNumber).collect(Collectors.toList());
+        // Sort and limit results
+        return collect.stream()
+                .sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity).reversed())
+                .limit(embeddingRoundNumber * queryTextsSub.size()).collect(Collectors.toList());
+    }
+
+    /**
+     * Convert RetrieveQueryResult and Retrieval to EmbeddingResult
+     *
+     * @param queryResult The query result containing retrieval information
+     * @param retrieval The retrieval data to be converted
+     * @return Converted EmbeddingResult
+     */
+    private EmbeddingResult convertToEmbeddingResult(RetrieveQueryResult queryResult,
+            Retrieval retrieval) {
+        EmbeddingResult embeddingResult = new EmbeddingResult();
+        BeanUtils.copyProperties(retrieval, embeddingResult);
+        embeddingResult.setDetectWord(queryResult.getQuery());
+        embeddingResult.setName(retrieval.getQuery());
+
+        // Convert metadata to string values
+        Map<String, String> metadata = retrieval.getMetadata().entrySet().stream().collect(
+                Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
+        embeddingResult.setMetadata(metadata);
+
+        return embeddingResult;
     }
 }
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java
index 5fa6df097..b16a6f186 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java
@@ -7,12 +7,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
 import org.apache.commons.lang3.StringUtils;
 import org.springframework.util.CollectionUtils;
 
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
@@ -66,7 +61,7 @@
             List<SchemaElementMatch> value = entry.getValue();
             if (!CollectionUtils.isEmpty(value)) {
                 value.removeIf(schemaElementMatch -> StringUtils
-                        .length(schemaElementMatch.getDetectWord()) <= 1);
+                        .length(schemaElementMatch.getDetectWord()) <= 1 && !schemaElementMatch.isLlmMatched());
             }
         }
     }
@@ -85,7 +80,7 @@
     }
 
     public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
-            Predicate<SchemaElement> needRemovePredicate) {
+            Predicate<SchemaElement> needRemovePredicate) {
         Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
                 chatQueryContext.getMapInfo().getDataSetElementMatches();
         for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java
index 583071c84..cd4b6772c 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java
@@ -57,4 +57,12 @@
     public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
             new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值",
                     "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
+
+    public static final Parameter EMBEDDING_MAPPER_USE_LLM =
+            new Parameter("s2.mapper.embedding.use-llm-enhance", "false", "使用LLM对召回的向量进行二次判断开关",
+                    "embedding的结果再通过一次LLM来筛选,这时候忽略各个向量阀值", "bool", "Mapper相关配置");
+
+    public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE =
+            new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']",
+                    "使用LLM召回二次处理时对问题分词词性的控制", "分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
 }
diff --git a/launchers/standalone/src/main/resources/s2-config.yaml b/launchers/standalone/src/main/resources/s2-config.yaml
index fb81d4bd8..133a78e5d 100644
--- a/launchers/standalone/src/main/resources/s2-config.yaml
+++ b/launchers/standalone/src/main/resources/s2-config.yaml
@@ -41,3 +41,5 @@ s2:
       threshold: 0.5
     min:
       threshold: 0.3
+    embedding:
+      use-llm-enhance: true
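
Note: the sketch below is an illustrative, hypothetical consumer of the new llmMatched flag and is not part of this change. The class name SchemaMatchInspector and the method splitByLlmMatched are invented for the example; it relies only on APIs that already appear in this diff (SchemaMapInfo#getDataSetElementMatches and the Lombok-generated SchemaElementMatch#isLlmMatched).

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;

// Hypothetical helper, not part of this PR: splits mapped elements by whether they were
// confirmed by the LLM filter (llmMatched set in EmbeddingMapper) or kept on similarity alone.
public class SchemaMatchInspector {

    public static Map<Boolean, List<SchemaElementMatch>> splitByLlmMatched(SchemaMapInfo mapInfo) {
        return mapInfo.getDataSetElementMatches().values().stream()
                .flatMap(List::stream)
                .collect(Collectors.partitioningBy(SchemaElementMatch::isLlmMatched));
    }
}
```

Downstream code could treat the true partition more leniently, which is what the relaxed condition in MapFilter above does by skipping the short-detect-word pruning for LLM-confirmed matches.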