[improvement][chat] Vector retrieval is enabled only in loose mode (#1899)

This commit is contained in:
lexluo09
2024-11-10 10:39:17 +08:00
committed by GitHub
parent e0e167fd40
commit ca4545bb15
2 changed files with 28 additions and 19 deletions

View File

@@ -20,22 +20,25 @@ import java.util.Objects;
*/ */
@Slf4j @Slf4j
public class EmbeddingMapper extends BaseMapper { public class EmbeddingMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) { public void doMap(ChatQueryContext chatQueryContext) {
// 1. query from embedding by queryText // Check if the map mode is LOOSE
if (MapModeEnum.STRICT.equals(chatQueryContext.getRequest().getMapModeEnum())) { if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) {
return; return;
} }
// 1. Query from embedding by queryText
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy); List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
// Process match results
HanlpHelper.transLetterOriginal(matchResults); HanlpHelper.transLetterOriginal(matchResults);
// 2. build SchemaElementMatch by info // 2. Build SchemaElementMatch based on match results
for (EmbeddingResult matchResult : matchResults) { for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId()); Long elementId = Retrieval.getLongId(matchResult.getId());
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId")); Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
// Skip if dataSetId is null
if (Objects.isNull(dataSetId)) { if (Objects.isNull(dataSetId)) {
continue; continue;
} }
@@ -43,14 +46,19 @@ public class EmbeddingMapper extends BaseMapper {
SchemaElementType.valueOf(matchResult.getMetadata().get("type")); SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId, SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
chatQueryContext.getSemanticSchema()); chatQueryContext.getSemanticSchema());
// Skip if schemaElement is null
if (schemaElement == null) { if (schemaElement == null) {
continue; continue;
} }
// Build SchemaElementMatch object
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY) .element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(matchResult.getName()).similarity(matchResult.getSimilarity()) .word(matchResult.getName()).similarity(matchResult.getSimilarity())
.detectWord(matchResult.getDetectWord()).build(); .detectWord(matchResult.getDetectWord()).build();
// 3. add to mapInfo
// 3. Add SchemaElementMatch to mapInfo
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
} }
} }

View File

@@ -33,32 +33,32 @@ public class KeywordMapper extends BaseMapper {
@Override @Override
public void doMap(ChatQueryContext chatQueryContext) { public void doMap(ChatQueryContext chatQueryContext) {
String queryText = chatQueryContext.getRequest().getQueryText(); String queryText = chatQueryContext.getRequest().getQueryText();
// 1.hanlpDict Match
// 1. hanlpDict Match
List<S2Term> terms = List<S2Term> terms =
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
HanlpDictMatchStrategy hanlpMatchStrategy = HanlpDictMatchStrategy hanlpMatchStrategy =
ContextUtils.getBean(HanlpDictMatchStrategy.class); ContextUtils.getBean(HanlpDictMatchStrategy.class);
List<HanlpMapResult> hanlpMatchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
convertMapResultToMapInfo(hanlpMatchResults, chatQueryContext, terms);
List<HanlpMapResult> matchResults = getMatches(chatQueryContext, hanlpMatchStrategy); // 2. database Match
convertHanlpMapResultToMapInfo(matchResults, chatQueryContext, terms);
// 2.database Match
DatabaseMatchStrategy databaseMatchStrategy = DatabaseMatchStrategy databaseMatchStrategy =
ContextUtils.getBean(DatabaseMatchStrategy.class); ContextUtils.getBean(DatabaseMatchStrategy.class);
List<DatabaseMapResult> databaseResults = List<DatabaseMapResult> databaseMatchResults =
getMatches(chatQueryContext, databaseMatchStrategy); getMatches(chatQueryContext, databaseMatchStrategy);
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults); convertMapResultToMapInfo(chatQueryContext, databaseMatchResults);
} }
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, private void convertMapResultToMapInfo(List<HanlpMapResult> mapResults,
ChatQueryContext chatQueryContext, List<S2Term> terms) { ChatQueryContext chatQueryContext, List<S2Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) { if (CollectionUtils.isEmpty(mapResults)) {
return; return;
} }
HanlpHelper.transLetterOriginal(mapResults); HanlpHelper.transLetterOriginal(mapResults);
Map<String, Long> wordNatureToFrequency = terms.stream() Map<String, Long> wordNatureToFrequency =
.collect(Collectors.toMap(entry -> entry.getWord() + entry.getNature(), terms.stream().collect(Collectors.toMap(term -> term.getWord() + term.getNature(),
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2)); term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
for (HanlpMapResult hanlpMapResult : mapResults) { for (HanlpMapResult hanlpMapResult : mapResults) {
@@ -74,9 +74,10 @@ public class KeywordMapper extends BaseMapper {
Long elementID = NatureHelper.getElementID(nature); Long elementID = NatureHelper.getElementID(nature);
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID, SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
chatQueryContext.getSemanticSchema()); chatQueryContext.getSemanticSchema());
if (element == null) { if (Objects.isNull(element)) {
continue; continue;
} }
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature); Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element).frequency(frequency).word(hanlpMapResult.getName()) .element(element).frequency(frequency).word(hanlpMapResult.getName())
@@ -88,7 +89,7 @@ public class KeywordMapper extends BaseMapper {
} }
} }
private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext, private void convertMapResultToMapInfo(ChatQueryContext chatQueryContext,
List<DatabaseMapResult> mapResults) { List<DatabaseMapResult> mapResults) {
for (DatabaseMapResult match : mapResults) { for (DatabaseMapResult match : mapResults) {
SchemaElement schemaElement = match.getSchemaElement(); SchemaElement schemaElement = match.getSchemaElement();