Mirror of https://github.com/tencentmusic/supersonic.git (synced 2025-12-10 19:51:00 +00:00)
[improvement][chat] Only vector retrieval is enabled in loose mode (#1899)
@@ -20,22 +20,25 @@ import java.util.Objects;
  */
 @Slf4j
 public class EmbeddingMapper extends BaseMapper {
 
     @Override
     public void doMap(ChatQueryContext chatQueryContext) {
-        // 1. query from embedding by queryText
-        if (MapModeEnum.STRICT.equals(chatQueryContext.getRequest().getMapModeEnum())) {
+        // Check if the map mode is LOOSE
+        if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) {
             return;
         }
 
+        // 1. Query from embedding by queryText
         EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
         List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
+        // Process match results
         HanlpHelper.transLetterOriginal(matchResults);
-        // 2. build SchemaElementMatch by info
+        // 2. Build SchemaElementMatch based on match results
         for (EmbeddingResult matchResult : matchResults) {
             Long elementId = Retrieval.getLongId(matchResult.getId());
             Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
 
+            // Skip if dataSetId is null
             if (Objects.isNull(dataSetId)) {
                 continue;
             }
@@ -43,14 +46,19 @@ public class EmbeddingMapper extends BaseMapper {
                     SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
             SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
                     chatQueryContext.getSemanticSchema());
+
+            // Skip if schemaElement is null
             if (schemaElement == null) {
                 continue;
             }
+
+            // Build SchemaElementMatch object
             SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                     .element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
                     .word(matchResult.getName()).similarity(matchResult.getSimilarity())
                     .detectWord(matchResult.getDetectWord()).build();
-            // 3. add to mapInfo
+
+            // 3. Add SchemaElementMatch to mapInfo
             addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
         }
     }
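The substantive change sits in the guard clause at the top of EmbeddingMapper.doMap: embedding-based (vector) retrieval used to be skipped only when the map mode was STRICT, and it now runs only when the map mode is LOOSE, which is what the commit title describes. The sketch below illustrates that gating pattern in isolation; MapMode and QueryRequest are simplified stand-ins, not the actual SuperSonic MapModeEnum or chat request classes.

// Standalone illustration of the LOOSE-mode guard; types are stand-ins,
// not the SuperSonic classes touched by this commit.
import java.util.List;

public class LooseModeGateSketch {

    enum MapMode { STRICT, LOOSE }

    record QueryRequest(String queryText, MapMode mapMode) {}

    // Placeholder for the embedding/vector retrieval step.
    static List<String> vectorRetrieve(String queryText) {
        return List.of("embedding match for: " + queryText);
    }

    static List<String> doMap(QueryRequest request) {
        // Before the change: return early only when the mode was STRICT.
        // After the change: return early unless the mode is LOOSE, so vector
        // retrieval is enabled exclusively in loose mode.
        if (request.mapMode() != MapMode.LOOSE) {
            return List.of();
        }
        return vectorRetrieve(request.queryText());
    }

    public static void main(String[] args) {
        System.out.println(doMap(new QueryRequest("play count by artist", MapMode.LOOSE)));
        System.out.println(doMap(new QueryRequest("play count by artist", MapMode.STRICT)));
    }
}

Note that inverting the condition is not a pure refactor: if MapModeEnum defines values besides STRICT and LOOSE, those modes were previously allowed through and are now excluded as well.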
@@ -33,32 +33,32 @@ public class KeywordMapper extends BaseMapper {
     @Override
     public void doMap(ChatQueryContext chatQueryContext) {
         String queryText = chatQueryContext.getRequest().getQueryText();
-        // 1.hanlpDict Match
+        // 1. hanlpDict Match
         List<S2Term> terms =
                 HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
         HanlpDictMatchStrategy hanlpMatchStrategy =
                 ContextUtils.getBean(HanlpDictMatchStrategy.class);
-        List<HanlpMapResult> matchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
-        convertHanlpMapResultToMapInfo(matchResults, chatQueryContext, terms);
+        List<HanlpMapResult> hanlpMatchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
+        convertMapResultToMapInfo(hanlpMatchResults, chatQueryContext, terms);
 
-        // 2.database Match
+        // 2. database Match
         DatabaseMatchStrategy databaseMatchStrategy =
                 ContextUtils.getBean(DatabaseMatchStrategy.class);
-        List<DatabaseMapResult> databaseResults =
+        List<DatabaseMapResult> databaseMatchResults =
                 getMatches(chatQueryContext, databaseMatchStrategy);
-        convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
+        convertMapResultToMapInfo(chatQueryContext, databaseMatchResults);
     }
 
-    private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults,
+    private void convertMapResultToMapInfo(List<HanlpMapResult> mapResults,
             ChatQueryContext chatQueryContext, List<S2Term> terms) {
         if (CollectionUtils.isEmpty(mapResults)) {
             return;
         }
         HanlpHelper.transLetterOriginal(mapResults);
-        Map<String, Long> wordNatureToFrequency = terms.stream()
-                .collect(Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
+        Map<String, Long> wordNatureToFrequency =
+                terms.stream().collect(Collectors.toMap(term -> term.getWord() + term.getNature(),
                         term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
 
         for (HanlpMapResult hanlpMapResult : mapResults) {
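One detail worth calling out in the reformatted wordNatureToFrequency stream above: Collectors.toMap is given a third argument, the merge function (value1, value2) -> value2, so when two segmented terms produce the same word+nature key the later frequency wins instead of the collector throwing. A minimal self-contained example of that standard-library behavior, using a hypothetical Term record rather than the HanLP S2Term type:

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class ToMapMergeSketch {

    // Hypothetical stand-in for a segmented term: word, nature (POS tag), frequency.
    record Term(String word, String nature, int frequency) {}

    public static void main(String[] args) {
        List<Term> terms = List.of(
                new Term("歌手", "n", 10),
                new Term("歌手", "n", 25),   // same word + nature key as the entry above
                new Term("播放量", "n", 7));

        // Key: word + nature; value: frequency widened to Long.
        // Without the merge function a duplicate key throws IllegalStateException;
        // (value1, value2) -> value2 keeps the most recently encountered value.
        Map<String, Long> wordNatureToFrequency = terms.stream()
                .collect(Collectors.toMap(term -> term.word() + term.nature(),
                        term -> Long.valueOf(term.frequency()), (value1, value2) -> value2));

        System.out.println(wordNatureToFrequency); // e.g. {歌手n=25, 播放量n=7}
    }
}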
@@ -74,9 +74,10 @@ public class KeywordMapper extends BaseMapper {
             Long elementID = NatureHelper.getElementID(nature);
             SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
                     chatQueryContext.getSemanticSchema());
-            if (element == null) {
+            if (Objects.isNull(element)) {
                 continue;
             }
+
             Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
             SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                     .element(element).frequency(frequency).word(hanlpMapResult.getName())
@@ -88,7 +89,7 @@ public class KeywordMapper extends BaseMapper {
         }
     }
 
-    private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext,
+    private void convertMapResultToMapInfo(ChatQueryContext chatQueryContext,
             List<DatabaseMapResult> mapResults) {
         for (DatabaseMapResult match : mapResults) {
             SchemaElement schemaElement = match.getSchemaElement();
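Both mappers end with the same accumulation shape: iterate the match results, skip entries whose data set or schema element cannot be resolved, wrap the rest in a match object, and group them by data set id. A compressed standalone sketch of that shape, with hypothetical record types in place of the SuperSonic EmbeddingResult and SchemaElementMatch classes:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class MatchAccumulationSketch {

    // Hypothetical stand-ins for the retrieval result and the schema-side match.
    record EmbeddingHit(String name, Long dataSetId, double similarity) {}
    record ElementMatch(String word, double similarity) {}

    static Map<Long, List<ElementMatch>> toMapInfo(List<EmbeddingHit> hits) {
        Map<Long, List<ElementMatch>> mapInfo = new HashMap<>();
        for (EmbeddingHit hit : hits) {
            // Skip hits that cannot be tied to a data set, mirroring the null checks in the diff.
            if (Objects.isNull(hit.dataSetId())) {
                continue;
            }
            ElementMatch match = new ElementMatch(hit.name(), hit.similarity());
            mapInfo.computeIfAbsent(hit.dataSetId(), id -> new ArrayList<>()).add(match);
        }
        return mapInfo;
    }

    public static void main(String[] args) {
        List<EmbeddingHit> hits = List.of(
                new EmbeddingHit("artist", 1L, 0.92),
                new EmbeddingHit("orphan", null, 0.88),
                new EmbeddingHit("play_count", 1L, 0.81));
        System.out.println(toMapInfo(hits)); // orphan is dropped; the rest group under data set 1
    }
}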