diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java index eb5d1134e..77024b855 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java @@ -20,22 +20,25 @@ import java.util.Objects; */ @Slf4j public class EmbeddingMapper extends BaseMapper { - - @Override public void doMap(ChatQueryContext chatQueryContext) { - // 1. query from embedding by queryText - if (MapModeEnum.STRICT.equals(chatQueryContext.getRequest().getMapModeEnum())) { + // Check if the map mode is LOOSE + if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) { return; } + + // 1. Query from embedding by queryText EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); List matchResults = getMatches(chatQueryContext, matchStrategy); + // Process match results HanlpHelper.transLetterOriginal(matchResults); - // 2. build SchemaElementMatch by info + // 2. Build SchemaElementMatch based on match results for (EmbeddingResult matchResult : matchResults) { Long elementId = Retrieval.getLongId(matchResult.getId()); Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId")); + + // Skip if dataSetId is null if (Objects.isNull(dataSetId)) { continue; } @@ -43,14 +46,19 @@ public class EmbeddingMapper extends BaseMapper { SchemaElementType.valueOf(matchResult.getMetadata().get("type")); SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId, chatQueryContext.getSemanticSchema()); + + // Skip if schemaElement is null if (schemaElement == null) { continue; } + + // Build SchemaElementMatch object SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() .element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY) .word(matchResult.getName()).similarity(matchResult.getSimilarity()) .detectWord(matchResult.getDetectWord()).build(); - // 3. add to mapInfo + + // 3. Add SchemaElementMatch to mapInfo addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index 0af886413..cfce186d8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -33,32 +33,32 @@ public class KeywordMapper extends BaseMapper { @Override public void doMap(ChatQueryContext chatQueryContext) { String queryText = chatQueryContext.getRequest().getQueryText(); - // 1.hanlpDict Match + + // 1. hanlpDict Match List terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class); + List hanlpMatchResults = getMatches(chatQueryContext, hanlpMatchStrategy); + convertMapResultToMapInfo(hanlpMatchResults, chatQueryContext, terms); - List matchResults = getMatches(chatQueryContext, hanlpMatchStrategy); - - convertHanlpMapResultToMapInfo(matchResults, chatQueryContext, terms); - - // 2.database Match + // 2. database Match DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class); - List databaseResults = + List databaseMatchResults = getMatches(chatQueryContext, databaseMatchStrategy); - convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults); + convertMapResultToMapInfo(chatQueryContext, databaseMatchResults); } - private void convertHanlpMapResultToMapInfo(List mapResults, + private void convertMapResultToMapInfo(List mapResults, ChatQueryContext chatQueryContext, List terms) { if (CollectionUtils.isEmpty(mapResults)) { return; } + HanlpHelper.transLetterOriginal(mapResults); - Map wordNatureToFrequency = terms.stream() - .collect(Collectors.toMap(entry -> entry.getWord() + entry.getNature(), + Map wordNatureToFrequency = + terms.stream().collect(Collectors.toMap(term -> term.getWord() + term.getNature(), term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2)); for (HanlpMapResult hanlpMapResult : mapResults) { @@ -74,9 +74,10 @@ public class KeywordMapper extends BaseMapper { Long elementID = NatureHelper.getElementID(nature); SchemaElement element = getSchemaElement(dataSetId, elementType, elementID, chatQueryContext.getSemanticSchema()); - if (element == null) { + if (Objects.isNull(element)) { continue; } + Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature); SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() .element(element).frequency(frequency).word(hanlpMapResult.getName()) @@ -88,7 +89,7 @@ public class KeywordMapper extends BaseMapper { } } - private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext, + private void convertMapResultToMapInfo(ChatQueryContext chatQueryContext, List mapResults) { for (DatabaseMapResult match : mapResults) { SchemaElement schemaElement = match.getSchemaElement();