[improvement](chat) Optimize the Schema Mapping rules (#1649)

Author: lexluo09
Date: 2024-09-10 22:58:34 +08:00
Committed by: GitHub
Parent: 183caf7931
Commit: 879696e493
20 changed files with 202 additions and 122 deletions

View File

@@ -51,13 +51,13 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
continue;
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
- double distance = embeddingRetrieval.getDistance();
- double score = parseContext.getQueryText().length() * (1 - distance);
+ double similarity = embeddingRetrieval.getSimilarity();
+ double score = parseContext.getQueryText().length() * similarity;
return PluginRecallResult.builder()
.plugin(plugin)
.dataSetIds(dataSetList)
.score(score)
- .distance(distance)
+ .distance(similarity)
.build();
}
}
@@ -73,7 +73,9 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals =
embeddingRetrievals.stream()
- .sorted(Comparator.comparingDouble(o -> Math.abs(o.getDistance())))
+ .sorted(
+ Comparator.comparingDouble(
+ o -> Math.abs(o.getSimilarity())))
.collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals);
}
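For reference, the hunks above lean on the identity used throughout this commit, similarity = 1 - distance, so the computed score (and the value passed to the still distance-named builder field) is unchanged. A minimal stand-alone sketch; the query text and retrieval values are illustrative, not taken from the repository:

    public class SimilarityScoreSketch {
        public static void main(String[] args) {
            String queryText = "total revenue last month"; // hypothetical parseContext query text
            double distance = 0.2;                         // old field on the retrieval
            double similarity = 1 - distance;              // new field on the retrieval

            double oldScore = queryText.length() * (1 - distance); // removed formula
            double newScore = queryText.length() * similarity;     // added formula

            // Prints true: both formulas yield the same score, so only the field name changes.
            System.out.println(oldScore == newScore);
        }
    }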

View File

@@ -65,7 +65,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
List<Retrieval> retrievals =
retrieveQueryResults.stream()
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
- .sorted(Comparator.comparingDouble(Retrieval::getDistance))
+ .sorted(Comparator.comparingDouble(Retrieval::getSimilarity))
.distinct()
.collect(Collectors.toList());
Set<Long> metricIds =

View File

@@ -164,7 +164,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
List<Retrieval> retrievals =
result.matches().stream()
.map(this::convertToRetrieval)
- .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
+ .sorted(Comparator.comparingDouble(Retrieval::getSimilarity))
.limit(num)
.collect(Collectors.toList());
@@ -177,7 +177,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
private Retrieval convertToRetrieval(EmbeddingMatch<TextSegment> embeddingMatch) {
Retrieval retrieval = new Retrieval();
TextSegment embedded = embeddingMatch.embedded();
- retrieval.setDistance(1 - embeddingMatch.score());
+ retrieval.setSimilarity(embeddingMatch.score());
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text());

View File

@@ -12,7 +12,7 @@ public class Retrieval {
protected String id;
- protected double distance;
+ protected double similarity;
protected String query;
@@ -35,7 +35,7 @@ public class Retrieval {
return false;
}
Retrieval retrieval = (Retrieval) o;
- return Double.compare(retrieval.distance, distance) == 0
+ return Double.compare(retrieval.similarity, similarity) == 0
&& Objects.equal(id, retrieval.id)
&& Objects.equal(query, retrieval.query)
&& Objects.equal(metadata, retrieval.metadata);
@@ -43,6 +43,6 @@ public class Retrieval {
@Override
public int hashCode() {
- return Objects.hashCode(id, distance, query, metadata);
+ return Objects.hashCode(id, similarity, query, metadata);
}
}

View File

@@ -12,8 +12,8 @@ import lombok.ToString;
@AllArgsConstructor
@NoArgsConstructor
public class SchemaElementMatch {
SchemaElement element;
double offset;
double similarity;
String detectWord;
String word;

View File

@@ -12,9 +12,6 @@ import java.util.Map;
public class EmbeddingResult extends MapResult {
private String id;
- private double distance;
private Map<String, String> metadata;
@Override

View File

@@ -10,16 +10,13 @@ import java.util.List;
@Data
@ToString
public class HanlpMapResult extends MapResult {
private List<String> natures;
- private int offset = 0;
- private double similarity;
- public HanlpMapResult(String name, List<String> natures, String detectWord) {
+ public HanlpMapResult(String name, List<String> natures, String detectWord, double similarity) {
this.name = name;
this.natures = natures;
this.detectWord = detectWord;
+ this.similarity = similarity;
}
@Override

View File

@@ -10,16 +10,17 @@ import java.io.Serializable;
public abstract class MapResult implements Serializable {
protected String name;
+ protected int offset;
protected int index;
protected String detectWord;
+ protected double similarity;
public abstract String getMapKey();
public Boolean lessSimilar(MapResult otherResult) {
String mapKey = this.getMapKey();
String otherMapKey = otherResult.getMapKey();
- return mapKey.equals(otherMapKey)
- && otherResult.getDetectWord().length() < otherResult.getDetectWord().length();
+ return mapKey.equals(otherMapKey) && otherResult.similarity < otherResult.similarity;
}
}
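Since similarity (and offset) now sit on the abstract MapResult, subclasses such as HanlpMapResult, DatabaseMapResult and EmbeddingResult can be ranked on the same field without casting. A minimal stand-alone sketch of that ranking; the Candidate class is a hypothetical stand-in for a MapResult subclass and the values are invented:

    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;

    public class SimilarityRankingSketch {
        // Hypothetical stand-in for a MapResult subclass such as HanlpMapResult.
        static class Candidate {
            final String name;
            final double similarity;
            Candidate(String name, double similarity) {
                this.name = name;
                this.similarity = similarity;
            }
        }

        public static void main(String[] args) {
            List<Candidate> results = new ArrayList<>();
            results.add(new Candidate("revenue", 0.85));
            results.add(new Candidate("net revenue", 1.0));
            results.add(new Candidate("region", 0.6));
            // Rank by the shared similarity field, highest first.
            results.sort(Comparator.comparingDouble((Candidate c) -> c.similarity).reversed());
            System.out.println(results.get(0).name); // prints "net revenue"
        }
    }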

View File

@@ -8,6 +8,7 @@ import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
+ import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
@@ -62,7 +63,9 @@ public class SearchService {
.map(
entry -> {
String name = entry.getKey().replace("#", " ");
- return new HanlpMapResult(name, entry.getValue(), key);
+ double similarity = EditDistanceUtils.getSimilarity(name, key);
+ return new HanlpMapResult(
+ name, entry.getValue(), key, similarity);
})
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toList());
@@ -109,8 +112,10 @@ public class SearchService {
.getType(),
""))
.collect(Collectors.toList());
name = StringUtils.reverse(name);
- return new HanlpMapResult(name, natures, key);
+ double similarity = EditDistanceUtils.getSimilarity(name, key);
+ return new HanlpMapResult(name, natures, key, similarity);
})
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toList());

View File

@@ -270,7 +270,10 @@ public class HanlpHelper {
if (orig != null) {
MapResult addMapResult =
new HanlpMapResult(
- orig, Arrays.asList(nature), hanlpMapResult.getDetectWord());
+ orig,
+ Arrays.asList(nature),
+ hanlpMapResult.getDetectWord(),
+ hanlpMapResult.getSimilarity());
mapResults.add((T) addMapResult);
isAdd = true;
}
@@ -301,7 +304,7 @@ public class HanlpHelper {
addMapResult.setDetectWord(embeddingResult.getDetectWord());
addMapResult.setId(embeddingResult.getId());
addMapResult.setMetadata(embeddingResult.getMetadata());
- addMapResult.setDistance(embeddingResult.getDistance());
+ addMapResult.setSimilarity(embeddingResult.getSimilarity());
mapResults.add((T) addMapResult);
isAdd = true;
}

View File

@@ -37,7 +37,7 @@ public abstract class BaseMapper implements SchemaMapper {
try {
doMap(chatQueryContext);
- filter(chatQueryContext);
+ MapFilter.filter(chatQueryContext);
} catch (Exception e) {
log.error("work error", e);
}
@@ -147,33 +147,4 @@ public abstract class BaseMapper implements SchemaMapper {
}
return matches;
}
- private void filter(ChatQueryContext chatQueryContext) {
- MapFilter.filterByDataSetId(chatQueryContext);
- MapFilter.filterByDetectWordLenLessThanOne(chatQueryContext);
- switch (chatQueryContext.getQueryDataType()) {
- case TAG:
- MapFilter.filterByQueryDataType(
- chatQueryContext, element -> !(element.getIsTag() > 0));
- break;
- case METRIC:
- MapFilter.filterByQueryDataType(
- chatQueryContext,
- element -> !SchemaElementType.METRIC.equals(element.getType()));
- break;
- case DIMENSION:
- MapFilter.filterByQueryDataType(
- chatQueryContext,
- element -> {
- boolean isDimensionOrValue =
- SchemaElementType.DIMENSION.equals(element.getType())
- || SchemaElementType.VALUE.equals(element.getType());
- return !isDimensionOrValue;
- });
- break;
- case ALL:
- default:
- break;
- }
- }
}

View File

@@ -49,7 +49,7 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
boolean isDeleted =
existResults.removeIf(
existResult -> {
- boolean delete = oneRoundResult.lessSimilar(existResult);
+ boolean delete = existResult.lessSimilar(oneRoundResult);
if (delete) {
log.info("deleted existResult:{}", existResult);
}
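With the receiver flipped, an already collected result is now removed when it compares as the less similar of the two (the real lessSimilar also requires both results to share the same map key). A stand-alone sketch of the removeIf pattern under that assumption; the numbers and the BiPredicate stand-in are illustrative only:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.BiPredicate;

    public class OneRoundDedupSketch {
        public static void main(String[] args) {
            // Similarities of results kept from earlier rounds.
            List<Double> existResults = new ArrayList<>(List.of(0.60, 0.95));
            double oneRoundResult = 0.90; // similarity of the result from the current round
            // Stand-in for MapResult.lessSimilar: true when the first result is less similar.
            BiPredicate<Double, Double> lessSimilar = (a, b) -> a < b;
            existResults.removeIf(existResult -> lessSimilar.test(existResult, oneRoundResult));
            System.out.println(existResults); // prints [0.95]: only the stronger existing match survives
        }
    }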

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
+ import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
@@ -49,9 +50,8 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
List<DatabaseMapResult> results = new ArrayList<>();
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
String name = entry.getKey();
- if (!name.contains(detectSegment)
- || mapperHelper.getSimilarity(detectSegment, name)
- < metricDimensionThresholdConfig) {
+ double similarity = EditDistanceUtils.getSimilarity(detectSegment, name);
+ if (!name.contains(detectSegment) || similarity < metricDimensionThresholdConfig) {
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
@@ -68,6 +68,7 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
databaseMapResult.setDetectWord(detectSegment);
databaseMapResult.setName(schemaElement.getName());
+ databaseMapResult.setSimilarity(similarity);
databaseMapResult.setSchemaElement(schemaElement);
results.add(databaseMapResult);
}

View File

@@ -49,7 +49,7 @@ public class EmbeddingMapper extends BaseMapper {
.element(schemaElement)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(matchResult.getName())
- .similarity(1 - matchResult.getDistance())
+ .similarity(matchResult.getSimilarity())
.detectWord(matchResult.getDetectWord())
.build();
// 3. add to mapInfo

View File

@@ -57,13 +57,14 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
Lists.partition(queryTextsList, embeddingMapperBatch);
for (List<String> queryTextsSub : queryTextsSubList) {
- detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
+ List<EmbeddingResult> oneRoundResults =
+ detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
+ selectResultInOneRound(results, oneRoundResults);
}
return new ArrayList<>(results);
}
- private void detectByQueryTextsSub(
- Set<EmbeddingResult> results,
+ private List<EmbeddingResult> detectByQueryTextsSub(
Set<Long> detectDataSetIds,
List<String> queryTextsSub,
ChatQueryContext chatQueryContext) {
@@ -89,7 +90,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
- return;
+ return new ArrayList<>();
}
// step3. build EmbeddingResults
List<EmbeddingResult> collect =
@@ -103,8 +104,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
if (!retrieveQueryResult
.getQuery()
.contains(retrieval.getQuery())) {
- return retrieval.getDistance()
- > 1 - threshold;
+ return retrieval.getSimilarity()
+ < threshold;
}
return false;
});
@@ -150,11 +151,9 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
int embeddingRoundNumber =
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
- List<EmbeddingResult> oneRoundResults =
- collect.stream()
- .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
- .limit(roundNumber)
- .collect(Collectors.toList());
- selectResultInOneRound(results, oneRoundResults);
+ return collect.stream()
+ .sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity))
+ .limit(roundNumber)
+ .collect(Collectors.toList());
}
}
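Assuming similarity = 1 - distance, the rewritten threshold check is equivalent to the old one: substituting distance = 1 - similarity, distance > 1 - threshold holds exactly when similarity < threshold. A small stand-alone check; the threshold value is illustrative, not the configured MapperConfig default:

    public class ThresholdEquivalenceSketch {
        public static void main(String[] args) {
            double threshold = 0.58; // illustrative value
            double[] similarities = {0.30, 0.58, 0.75, 0.99};
            for (double similarity : similarities) {
                double distance = 1 - similarity;
                boolean oldCheck = distance > 1 - threshold; // removed predicate
                boolean newCheck = similarity < threshold;   // added predicate
                // The two predicates agree for every sampled value.
                System.out.printf("similarity=%.2f old=%b new=%b%n", similarity, oldCheck, newCheck);
            }
        }
    }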

View File

@@ -72,10 +72,15 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
hanlpMapResults.stream()
.filter(
term ->
- mapperHelper.getSimilarity(detectSegment, term.getName())
+ term.getSimilarity()
>= getThresholdMatch(
term.getNatures(), chatQueryContext))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
+ .map(
+ parseResult -> {
+ parseResult.setOffset(offset);
+ return parseResult;
+ })
.collect(Collectors.toCollection(LinkedHashSet::new));
log.debug(
@@ -83,18 +88,6 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
detectSegment,
hanlpMapResults);
- hanlpMapResults =
- hanlpMapResults.stream()
- .map(
- parseResult -> {
- parseResult.setOffset(offset);
- parseResult.setSimilarity(
- mapperHelper.getSimilarity(
- detectSegment, parseResult.getName()));
- return parseResult;
- })
- .collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only M dimensionValue or N-M metric/dimension value per rond.
int oneDetectionValueSize =
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));

View File

@@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
+ import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -103,7 +104,6 @@ public class KeywordMapper extends BaseMapper {
private void convertDatabaseMapResultToMapInfo(
ChatQueryContext chatQueryContext, List<DatabaseMapResult> mapResults) {
- MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
for (DatabaseMapResult match : mapResults) {
SchemaElement schemaElement = match.getSchemaElement();
Set<Long> regElementSet =
@@ -118,7 +118,7 @@ public class KeywordMapper extends BaseMapper {
.detectWord(match.getDetectWord())
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.similarity(
- mapperHelper.getSimilarity(
+ EditDistanceUtils.getSimilarity(
match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);

View File

@@ -7,14 +7,46 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
+ import java.util.ArrayList;
+ import java.util.Comparator;
+ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
public class MapFilter {
+ public static void filter(ChatQueryContext chatQueryContext) {
+ filterByDataSetId(chatQueryContext);
+ filterByDetectWordLenLessThanOne(chatQueryContext);
+ switch (chatQueryContext.getQueryDataType()) {
+ case TAG:
+ filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
+ break;
+ case METRIC:
+ filterByQueryDataType(
+ chatQueryContext,
+ element -> !SchemaElementType.METRIC.equals(element.getType()));
+ break;
+ case DIMENSION:
+ filterByQueryDataType(
+ chatQueryContext,
+ element -> {
+ boolean isDimensionOrValue =
+ SchemaElementType.DIMENSION.equals(element.getType())
+ || SchemaElementType.VALUE.equals(element.getType());
+ return !isDimensionOrValue;
+ });
+ break;
+ case ALL:
+ default:
+ break;
+ }
+ filterByRules(chatQueryContext);
+ }
public static void filterByDataSetId(ChatQueryContext chatQueryContext) {
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
@@ -44,25 +76,93 @@ public class MapFilter {
public static void filterByQueryDataType(
ChatQueryContext chatQueryContext, Predicate<SchemaElement> needRemovePredicate) {
- chatQueryContext
- .getMapInfo()
- .getDataSetElementMatches()
- .values()
- .forEach(
- schemaElementMatches -> {
- schemaElementMatches.removeIf(
- schemaElementMatch -> {
- SchemaElement element = schemaElementMatch.getElement();
- SchemaElementType type = element.getType();
+ Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
+ chatQueryContext.getMapInfo().getDataSetElementMatches();
+ for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
+ List<SchemaElementMatch> schemaElementMatches = entry.getValue();
+ schemaElementMatches.removeIf(
+ schemaElementMatch -> {
+ SchemaElement element = schemaElementMatch.getElement();
+ SchemaElementType type = element.getType();
- boolean isEntityOrDatasetOrId =
- SchemaElementType.ENTITY.equals(type)
- || SchemaElementType.DATASET.equals(type)
- || SchemaElementType.ID.equals(type);
+ boolean isEntityOrDatasetOrId =
+ SchemaElementType.ENTITY.equals(type)
+ || SchemaElementType.DATASET.equals(type)
+ || SchemaElementType.ID.equals(type);
- return !isEntityOrDatasetOrId
- && needRemovePredicate.test(element);
- });
- });
+ return !isEntityOrDatasetOrId && needRemovePredicate.test(element);
+ });
+ }
}
+ public static void filterByRules(ChatQueryContext chatQueryContext) {
+ Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
+ chatQueryContext.getMapInfo().getDataSetElementMatches();
+ for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
+ List<SchemaElementMatch> elementMatches = entry.getValue();
+ filterByExactMatch(elementMatches);
+ filterInExactMatch(elementMatches);
+ }
+ }
+ public static List<SchemaElementMatch> filterByExactMatch(List<SchemaElementMatch> matches) {
+ // Group by detectWord
+ Map<String, List<SchemaElementMatch>> groupedByDetectWord =
+ matches.stream().collect(Collectors.groupingBy(SchemaElementMatch::getDetectWord));
+ List<SchemaElementMatch> result = new ArrayList<>();
+ for (Map.Entry<String, List<SchemaElementMatch>> entry : groupedByDetectWord.entrySet()) {
+ List<SchemaElementMatch> group = entry.getValue();
+ // Filter out objects with similarity=1.0
+ List<SchemaElementMatch> fullMatches =
+ group.stream()
+ .filter(SchemaElementMatch::isFullMatched)
+ .collect(Collectors.toList());
+ if (!fullMatches.isEmpty()) {
+ // If there are objects with similarity=1.0, choose the one with the longest
+ // detectWord and smallest offset
+ SchemaElementMatch bestMatch =
+ fullMatches.stream()
+ .max(
+ Comparator.comparing(
+ (SchemaElementMatch match) ->
+ match.getDetectWord().length()))
+ .orElse(null);
+ if (bestMatch != null) {
+ result.add(bestMatch);
+ }
+ } else {
+ // If there are no objects with similarity=1.0, keep all objects with similarity<1.0
+ result.addAll(group);
+ }
+ }
+ return result;
+ }
+ public static List<SchemaElementMatch> filterInExactMatch(List<SchemaElementMatch> matches) {
+ Map<String, List<SchemaElementMatch>> fullMatches =
+ matches.stream()
+ .filter(schemaElementMatch -> schemaElementMatch.isFullMatched())
+ .collect(Collectors.groupingBy(SchemaElementMatch::getDetectWord));
+ Set<String> keys = new HashSet<>(fullMatches.keySet());
+ for (String key1 : keys) {
+ for (String key2 : keys) {
+ if (!key1.equals(key2) && key1.contains(key2)) {
+ fullMatches.remove(key2);
+ }
+ }
+ }
+ List<SchemaElementMatch> notFullMatches =
+ matches.stream()
+ .filter(schemaElementMatch -> !schemaElementMatch.isFullMatched())
+ .collect(Collectors.toList());
+ List<SchemaElementMatch> mergedMatches = new ArrayList<>();
+ fullMatches.values().forEach(mergedMatches::addAll);
+ mergedMatches.addAll(notFullMatches);
+ return mergedMatches;
+ }
}
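filterInExactMatch keeps a fully matched detectWord only when no longer full match contains it, and carries partial matches over untouched. A stand-alone sketch of that containment rule on plain strings; the sample words are invented:

    import java.util.HashSet;
    import java.util.LinkedHashSet;
    import java.util.List;
    import java.util.Set;

    public class ContainmentFilterSketch {
        public static void main(String[] args) {
            // detectWords of fully matched elements, e.g. "profit" and "net profit" both matched exactly.
            Set<String> fullMatchedWords = new LinkedHashSet<>(List.of("profit", "net profit", "region"));
            Set<String> keys = new HashSet<>(fullMatchedWords);
            for (String key1 : keys) {
                for (String key2 : keys) {
                    // Drop the shorter word when a longer full match contains it.
                    if (!key1.equals(key2) && key1.contains(key2)) {
                        fullMatchedWords.remove(key2);
                    }
                }
            }
            System.out.println(fullMatchedWords); // prints [net profit, region]
        }
    }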

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.mapper;
- import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import lombok.Data;
@@ -67,19 +66,4 @@ public class MapperHelper {
}
return false;
}
- /**
- * * get similarity
- *
- * @param detectSegment
- * @param matchName
- * @return
- */
- public double getSimilarity(String detectSegment, String matchName) {
- String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
- String matchNameLower = matchName == null ? null : matchName.toLowerCase();
- return 1
- - (double) EditDistance.compute(detectSegmentLower, matchNameLower)
- / Math.max(matchName.length(), detectSegment.length());
- }
}

View File

@@ -0,0 +1,27 @@
+ package com.tencent.supersonic.headless.chat.utils;
+ import com.hankcs.hanlp.algorithm.EditDistance;
+ import lombok.Data;
+ import lombok.extern.slf4j.Slf4j;
+ import org.springframework.stereotype.Service;
+ @Data
+ @Service
+ @Slf4j
+ public class EditDistanceUtils {
+ /**
+ * * get similarity
+ *
+ * @param detectSegment
+ * @param matchName
+ * @return
+ */
+ public static double getSimilarity(String detectSegment, String matchName) {
+ String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
+ String matchNameLower = matchName == null ? null : matchName.toLowerCase();
+ return 1
+ - (double) EditDistance.compute(detectSegmentLower, matchNameLower)
+ / Math.max(matchName.length(), detectSegment.length());
+ }
+ }
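The new helper is a static lift of the former MapperHelper.getSimilarity (removed above), so existing thresholds keep their meaning. A worked example of the formula, assuming EditDistance.compute returns the usual Levenshtein distance; the words are illustrative:

    public class EditDistanceSimilaritySketch {
        public static void main(String[] args) {
            // getSimilarity("profit", "profits"): one insertion, so the edit distance is 1,
            // the longer string has 7 characters, and similarity = 1 - 1/7.
            double similarity = 1 - 1.0 / Math.max("profit".length(), "profits".length());
            System.out.printf("%.3f%n", similarity); // prints 0.857
        }
    }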