[improvement][chat] Only vector retrieval is enabled in loose mode (#1899)

This commit is contained in:
lexluo09
2024-11-10 10:39:17 +08:00
committed by GitHub
parent e0e167fd40
commit ca4545bb15
2 changed files with 28 additions and 19 deletions

View File

@@ -20,22 +20,25 @@ import java.util.Objects;
*/
@Slf4j
public class EmbeddingMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
// 1. query from embedding by queryText
if (MapModeEnum.STRICT.equals(chatQueryContext.getRequest().getMapModeEnum())) {
// Check if the map mode is LOOSE
if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) {
return;
}
// 1. Query from embedding by queryText
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
// Process match results
HanlpHelper.transLetterOriginal(matchResults);
// 2. build SchemaElementMatch by info
// 2. Build SchemaElementMatch based on match results
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
// Skip if dataSetId is null
if (Objects.isNull(dataSetId)) {
continue;
}
@@ -43,14 +46,19 @@ public class EmbeddingMapper extends BaseMapper {
SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
chatQueryContext.getSemanticSchema());
// Skip if schemaElement is null
if (schemaElement == null) {
continue;
}
// Build SchemaElementMatch object
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(matchResult.getName()).similarity(matchResult.getSimilarity())
.detectWord(matchResult.getDetectWord()).build();
// 3. add to mapInfo
// 3. Add SchemaElementMatch to mapInfo
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}

View File

@@ -33,32 +33,32 @@ public class KeywordMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
String queryText = chatQueryContext.getRequest().getQueryText();
// 1.hanlpDict Match
// 1. hanlpDict Match
List<S2Term> terms =
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
HanlpDictMatchStrategy hanlpMatchStrategy =
ContextUtils.getBean(HanlpDictMatchStrategy.class);
List<HanlpMapResult> hanlpMatchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
convertMapResultToMapInfo(hanlpMatchResults, chatQueryContext, terms);
List<HanlpMapResult> matchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
convertHanlpMapResultToMapInfo(matchResults, chatQueryContext, terms);
// 2.database Match
// 2. database Match
DatabaseMatchStrategy databaseMatchStrategy =
ContextUtils.getBean(DatabaseMatchStrategy.class);
List<DatabaseMapResult> databaseResults =
List<DatabaseMapResult> databaseMatchResults =
getMatches(chatQueryContext, databaseMatchStrategy);
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
convertMapResultToMapInfo(chatQueryContext, databaseMatchResults);
}
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults,
private void convertMapResultToMapInfo(List<HanlpMapResult> mapResults,
ChatQueryContext chatQueryContext, List<S2Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
return;
}
HanlpHelper.transLetterOriginal(mapResults);
Map<String, Long> wordNatureToFrequency = terms.stream()
.collect(Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
Map<String, Long> wordNatureToFrequency =
terms.stream().collect(Collectors.toMap(term -> term.getWord() + term.getNature(),
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
for (HanlpMapResult hanlpMapResult : mapResults) {
@@ -74,9 +74,10 @@ public class KeywordMapper extends BaseMapper {
Long elementID = NatureHelper.getElementID(nature);
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
chatQueryContext.getSemanticSchema());
if (element == null) {
if (Objects.isNull(element)) {
continue;
}
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element).frequency(frequency).word(hanlpMapResult.getName())
@@ -88,7 +89,7 @@ public class KeywordMapper extends BaseMapper {
}
}
private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext,
private void convertMapResultToMapInfo(ChatQueryContext chatQueryContext,
List<DatabaseMapResult> mapResults) {
for (DatabaseMapResult match : mapResults) {
SchemaElement schemaElement = match.getSchemaElement();