mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
[improvement](chat) Optimize the Schema Mapping rules (#1649)
This commit is contained in:
@@ -51,13 +51,13 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||||
double distance = embeddingRetrieval.getDistance();
|
double similarity = embeddingRetrieval.getSimilarity();
|
||||||
double score = parseContext.getQueryText().length() * (1 - distance);
|
double score = parseContext.getQueryText().length() * similarity;
|
||||||
return PluginRecallResult.builder()
|
return PluginRecallResult.builder()
|
||||||
.plugin(plugin)
|
.plugin(plugin)
|
||||||
.dataSetIds(dataSetList)
|
.dataSetIds(dataSetList)
|
||||||
.score(score)
|
.score(score)
|
||||||
.distance(distance)
|
.distance(similarity)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -73,7 +73,9 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
|||||||
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
|
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||||
embeddingRetrievals =
|
embeddingRetrievals =
|
||||||
embeddingRetrievals.stream()
|
embeddingRetrievals.stream()
|
||||||
.sorted(Comparator.comparingDouble(o -> Math.abs(o.getDistance())))
|
.sorted(
|
||||||
|
Comparator.comparingDouble(
|
||||||
|
o -> Math.abs(o.getSimilarity())))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
embeddingResp.setRetrieval(embeddingRetrievals);
|
embeddingResp.setRetrieval(embeddingRetrievals);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
|||||||
List<Retrieval> retrievals =
|
List<Retrieval> retrievals =
|
||||||
retrieveQueryResults.stream()
|
retrieveQueryResults.stream()
|
||||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
|
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
|
||||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance))
|
.sorted(Comparator.comparingDouble(Retrieval::getSimilarity))
|
||||||
.distinct()
|
.distinct()
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
Set<Long> metricIds =
|
Set<Long> metricIds =
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
List<Retrieval> retrievals =
|
List<Retrieval> retrievals =
|
||||||
result.matches().stream()
|
result.matches().stream()
|
||||||
.map(this::convertToRetrieval)
|
.map(this::convertToRetrieval)
|
||||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
.sorted(Comparator.comparingDouble(Retrieval::getSimilarity))
|
||||||
.limit(num)
|
.limit(num)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
@@ -177,7 +177,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
private Retrieval convertToRetrieval(EmbeddingMatch<TextSegment> embeddingMatch) {
|
private Retrieval convertToRetrieval(EmbeddingMatch<TextSegment> embeddingMatch) {
|
||||||
Retrieval retrieval = new Retrieval();
|
Retrieval retrieval = new Retrieval();
|
||||||
TextSegment embedded = embeddingMatch.embedded();
|
TextSegment embedded = embeddingMatch.embedded();
|
||||||
retrieval.setDistance(1 - embeddingMatch.score());
|
retrieval.setSimilarity(embeddingMatch.score());
|
||||||
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||||
retrieval.setQuery(embedded.text());
|
retrieval.setQuery(embedded.text());
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ public class Retrieval {
|
|||||||
|
|
||||||
protected String id;
|
protected String id;
|
||||||
|
|
||||||
protected double distance;
|
protected double similarity;
|
||||||
|
|
||||||
protected String query;
|
protected String query;
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ public class Retrieval {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
Retrieval retrieval = (Retrieval) o;
|
Retrieval retrieval = (Retrieval) o;
|
||||||
return Double.compare(retrieval.distance, distance) == 0
|
return Double.compare(retrieval.similarity, similarity) == 0
|
||||||
&& Objects.equal(id, retrieval.id)
|
&& Objects.equal(id, retrieval.id)
|
||||||
&& Objects.equal(query, retrieval.query)
|
&& Objects.equal(query, retrieval.query)
|
||||||
&& Objects.equal(metadata, retrieval.metadata);
|
&& Objects.equal(metadata, retrieval.metadata);
|
||||||
@@ -43,6 +43,6 @@ public class Retrieval {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hashCode(id, distance, query, metadata);
|
return Objects.hashCode(id, similarity, query, metadata);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import lombok.ToString;
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class SchemaElementMatch {
|
public class SchemaElementMatch {
|
||||||
|
|
||||||
SchemaElement element;
|
SchemaElement element;
|
||||||
|
double offset;
|
||||||
double similarity;
|
double similarity;
|
||||||
String detectWord;
|
String detectWord;
|
||||||
String word;
|
String word;
|
||||||
|
|||||||
@@ -12,9 +12,6 @@ import java.util.Map;
|
|||||||
public class EmbeddingResult extends MapResult {
|
public class EmbeddingResult extends MapResult {
|
||||||
|
|
||||||
private String id;
|
private String id;
|
||||||
|
|
||||||
private double distance;
|
|
||||||
|
|
||||||
private Map<String, String> metadata;
|
private Map<String, String> metadata;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -10,16 +10,13 @@ import java.util.List;
|
|||||||
@Data
|
@Data
|
||||||
@ToString
|
@ToString
|
||||||
public class HanlpMapResult extends MapResult {
|
public class HanlpMapResult extends MapResult {
|
||||||
|
|
||||||
private List<String> natures;
|
private List<String> natures;
|
||||||
private int offset = 0;
|
|
||||||
|
|
||||||
private double similarity;
|
public HanlpMapResult(String name, List<String> natures, String detectWord, double similarity) {
|
||||||
|
|
||||||
public HanlpMapResult(String name, List<String> natures, String detectWord) {
|
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.natures = natures;
|
this.natures = natures;
|
||||||
this.detectWord = detectWord;
|
this.detectWord = detectWord;
|
||||||
|
this.similarity = similarity;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -10,16 +10,17 @@ import java.io.Serializable;
|
|||||||
public abstract class MapResult implements Serializable {
|
public abstract class MapResult implements Serializable {
|
||||||
|
|
||||||
protected String name;
|
protected String name;
|
||||||
|
protected int offset;
|
||||||
|
|
||||||
protected int index;
|
|
||||||
protected String detectWord;
|
protected String detectWord;
|
||||||
|
|
||||||
|
protected double similarity;
|
||||||
|
|
||||||
public abstract String getMapKey();
|
public abstract String getMapKey();
|
||||||
|
|
||||||
public Boolean lessSimilar(MapResult otherResult) {
|
public Boolean lessSimilar(MapResult otherResult) {
|
||||||
String mapKey = this.getMapKey();
|
String mapKey = this.getMapKey();
|
||||||
String otherMapKey = otherResult.getMapKey();
|
String otherMapKey = otherResult.getMapKey();
|
||||||
return mapKey.equals(otherMapKey)
|
return mapKey.equals(otherMapKey) && otherResult.similarity < otherResult.similarity;
|
||||||
&& otherResult.getDetectWord().length() < otherResult.getDetectWord().length();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import com.hankcs.hanlp.seg.common.Term;
|
|||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||||
|
import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -62,7 +63,9 @@ public class SearchService {
|
|||||||
.map(
|
.map(
|
||||||
entry -> {
|
entry -> {
|
||||||
String name = entry.getKey().replace("#", " ");
|
String name = entry.getKey().replace("#", " ");
|
||||||
return new HanlpMapResult(name, entry.getValue(), key);
|
double similarity = EditDistanceUtils.getSimilarity(name, key);
|
||||||
|
return new HanlpMapResult(
|
||||||
|
name, entry.getValue(), key, similarity);
|
||||||
})
|
})
|
||||||
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
@@ -109,8 +112,10 @@ public class SearchService {
|
|||||||
.getType(),
|
.getType(),
|
||||||
""))
|
""))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
name = StringUtils.reverse(name);
|
name = StringUtils.reverse(name);
|
||||||
return new HanlpMapResult(name, natures, key);
|
double similarity = EditDistanceUtils.getSimilarity(name, key);
|
||||||
|
return new HanlpMapResult(name, natures, key, similarity);
|
||||||
})
|
})
|
||||||
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|||||||
@@ -270,7 +270,10 @@ public class HanlpHelper {
|
|||||||
if (orig != null) {
|
if (orig != null) {
|
||||||
MapResult addMapResult =
|
MapResult addMapResult =
|
||||||
new HanlpMapResult(
|
new HanlpMapResult(
|
||||||
orig, Arrays.asList(nature), hanlpMapResult.getDetectWord());
|
orig,
|
||||||
|
Arrays.asList(nature),
|
||||||
|
hanlpMapResult.getDetectWord(),
|
||||||
|
hanlpMapResult.getSimilarity());
|
||||||
mapResults.add((T) addMapResult);
|
mapResults.add((T) addMapResult);
|
||||||
isAdd = true;
|
isAdd = true;
|
||||||
}
|
}
|
||||||
@@ -301,7 +304,7 @@ public class HanlpHelper {
|
|||||||
addMapResult.setDetectWord(embeddingResult.getDetectWord());
|
addMapResult.setDetectWord(embeddingResult.getDetectWord());
|
||||||
addMapResult.setId(embeddingResult.getId());
|
addMapResult.setId(embeddingResult.getId());
|
||||||
addMapResult.setMetadata(embeddingResult.getMetadata());
|
addMapResult.setMetadata(embeddingResult.getMetadata());
|
||||||
addMapResult.setDistance(embeddingResult.getDistance());
|
addMapResult.setSimilarity(embeddingResult.getSimilarity());
|
||||||
mapResults.add((T) addMapResult);
|
mapResults.add((T) addMapResult);
|
||||||
isAdd = true;
|
isAdd = true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
doMap(chatQueryContext);
|
doMap(chatQueryContext);
|
||||||
filter(chatQueryContext);
|
MapFilter.filter(chatQueryContext);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("work error", e);
|
log.error("work error", e);
|
||||||
}
|
}
|
||||||
@@ -147,33 +147,4 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
return matches;
|
return matches;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void filter(ChatQueryContext chatQueryContext) {
|
|
||||||
MapFilter.filterByDataSetId(chatQueryContext);
|
|
||||||
MapFilter.filterByDetectWordLenLessThanOne(chatQueryContext);
|
|
||||||
switch (chatQueryContext.getQueryDataType()) {
|
|
||||||
case TAG:
|
|
||||||
MapFilter.filterByQueryDataType(
|
|
||||||
chatQueryContext, element -> !(element.getIsTag() > 0));
|
|
||||||
break;
|
|
||||||
case METRIC:
|
|
||||||
MapFilter.filterByQueryDataType(
|
|
||||||
chatQueryContext,
|
|
||||||
element -> !SchemaElementType.METRIC.equals(element.getType()));
|
|
||||||
break;
|
|
||||||
case DIMENSION:
|
|
||||||
MapFilter.filterByQueryDataType(
|
|
||||||
chatQueryContext,
|
|
||||||
element -> {
|
|
||||||
boolean isDimensionOrValue =
|
|
||||||
SchemaElementType.DIMENSION.equals(element.getType())
|
|
||||||
|| SchemaElementType.VALUE.equals(element.getType());
|
|
||||||
return !isDimensionOrValue;
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
case ALL:
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
|
|||||||
boolean isDeleted =
|
boolean isDeleted =
|
||||||
existResults.removeIf(
|
existResults.removeIf(
|
||||||
existResult -> {
|
existResult -> {
|
||||||
boolean delete = oneRoundResult.lessSimilar(existResult);
|
boolean delete = existResult.lessSimilar(oneRoundResult);
|
||||||
if (delete) {
|
if (delete) {
|
||||||
log.info("deleted existResult:{}", existResult);
|
log.info("deleted existResult:{}", existResult);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
||||||
|
import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -49,9 +50,8 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
|
|||||||
List<DatabaseMapResult> results = new ArrayList<>();
|
List<DatabaseMapResult> results = new ArrayList<>();
|
||||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
String name = entry.getKey();
|
String name = entry.getKey();
|
||||||
if (!name.contains(detectSegment)
|
double similarity = EditDistanceUtils.getSimilarity(detectSegment, name);
|
||||||
|| mapperHelper.getSimilarity(detectSegment, name)
|
if (!name.contains(detectSegment) || similarity < metricDimensionThresholdConfig) {
|
||||||
< metricDimensionThresholdConfig) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Set<SchemaElement> schemaElements = entry.getValue();
|
Set<SchemaElement> schemaElements = entry.getValue();
|
||||||
@@ -68,6 +68,7 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
|
|||||||
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
|
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
|
||||||
databaseMapResult.setDetectWord(detectSegment);
|
databaseMapResult.setDetectWord(detectSegment);
|
||||||
databaseMapResult.setName(schemaElement.getName());
|
databaseMapResult.setName(schemaElement.getName());
|
||||||
|
databaseMapResult.setSimilarity(similarity);
|
||||||
databaseMapResult.setSchemaElement(schemaElement);
|
databaseMapResult.setSchemaElement(schemaElement);
|
||||||
results.add(databaseMapResult);
|
results.add(databaseMapResult);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
.element(schemaElement)
|
.element(schemaElement)
|
||||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||||
.word(matchResult.getName())
|
.word(matchResult.getName())
|
||||||
.similarity(1 - matchResult.getDistance())
|
.similarity(matchResult.getSimilarity())
|
||||||
.detectWord(matchResult.getDetectWord())
|
.detectWord(matchResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
// 3. add to mapInfo
|
// 3. add to mapInfo
|
||||||
|
|||||||
@@ -57,13 +57,14 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
Lists.partition(queryTextsList, embeddingMapperBatch);
|
Lists.partition(queryTextsList, embeddingMapperBatch);
|
||||||
|
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
|
List<EmbeddingResult> oneRoundResults =
|
||||||
|
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
|
||||||
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
}
|
}
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void detectByQueryTextsSub(
|
private List<EmbeddingResult> detectByQueryTextsSub(
|
||||||
Set<EmbeddingResult> results,
|
|
||||||
Set<Long> detectDataSetIds,
|
Set<Long> detectDataSetIds,
|
||||||
List<String> queryTextsSub,
|
List<String> queryTextsSub,
|
||||||
ChatQueryContext chatQueryContext) {
|
ChatQueryContext chatQueryContext) {
|
||||||
@@ -89,7 +90,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
return;
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
// step3. build EmbeddingResults
|
// step3. build EmbeddingResults
|
||||||
List<EmbeddingResult> collect =
|
List<EmbeddingResult> collect =
|
||||||
@@ -103,8 +104,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
if (!retrieveQueryResult
|
if (!retrieveQueryResult
|
||||||
.getQuery()
|
.getQuery()
|
||||||
.contains(retrieval.getQuery())) {
|
.contains(retrieval.getQuery())) {
|
||||||
return retrieval.getDistance()
|
return retrieval.getSimilarity()
|
||||||
> 1 - threshold;
|
< threshold;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
@@ -150,11 +151,9 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
|||||||
int embeddingRoundNumber =
|
int embeddingRoundNumber =
|
||||||
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||||
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
||||||
List<EmbeddingResult> oneRoundResults =
|
return collect.stream()
|
||||||
collect.stream()
|
.sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity))
|
||||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
.limit(roundNumber)
|
||||||
.limit(roundNumber)
|
.collect(Collectors.toList());
|
||||||
.collect(Collectors.toList());
|
|
||||||
selectResultInOneRound(results, oneRoundResults);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,10 +72,15 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
|
|||||||
hanlpMapResults.stream()
|
hanlpMapResults.stream()
|
||||||
.filter(
|
.filter(
|
||||||
term ->
|
term ->
|
||||||
mapperHelper.getSimilarity(detectSegment, term.getName())
|
term.getSimilarity()
|
||||||
>= getThresholdMatch(
|
>= getThresholdMatch(
|
||||||
term.getNatures(), chatQueryContext))
|
term.getNatures(), chatQueryContext))
|
||||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||||
|
.map(
|
||||||
|
parseResult -> {
|
||||||
|
parseResult.setOffset(offset);
|
||||||
|
return parseResult;
|
||||||
|
})
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -83,18 +88,6 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
|
|||||||
detectSegment,
|
detectSegment,
|
||||||
hanlpMapResults);
|
hanlpMapResults);
|
||||||
|
|
||||||
hanlpMapResults =
|
|
||||||
hanlpMapResults.stream()
|
|
||||||
.map(
|
|
||||||
parseResult -> {
|
|
||||||
parseResult.setOffset(offset);
|
|
||||||
parseResult.setSimilarity(
|
|
||||||
mapperHelper.getSimilarity(
|
|
||||||
detectSegment, parseResult.getName()));
|
|
||||||
return parseResult;
|
|
||||||
})
|
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
// step5. take only M dimensionValue or N-M metric/dimension value per rond.
|
// step5. take only M dimensionValue or N-M metric/dimension value per rond.
|
||||||
int oneDetectionValueSize =
|
int oneDetectionValueSize =
|
||||||
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
|
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
|||||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||||
|
import com.tencent.supersonic.headless.chat.utils.EditDistanceUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -103,7 +104,6 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
|
|
||||||
private void convertDatabaseMapResultToMapInfo(
|
private void convertDatabaseMapResultToMapInfo(
|
||||||
ChatQueryContext chatQueryContext, List<DatabaseMapResult> mapResults) {
|
ChatQueryContext chatQueryContext, List<DatabaseMapResult> mapResults) {
|
||||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
|
||||||
for (DatabaseMapResult match : mapResults) {
|
for (DatabaseMapResult match : mapResults) {
|
||||||
SchemaElement schemaElement = match.getSchemaElement();
|
SchemaElement schemaElement = match.getSchemaElement();
|
||||||
Set<Long> regElementSet =
|
Set<Long> regElementSet =
|
||||||
@@ -118,7 +118,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.detectWord(match.getDetectWord())
|
.detectWord(match.getDetectWord())
|
||||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||||
.similarity(
|
.similarity(
|
||||||
mapperHelper.getSimilarity(
|
EditDistanceUtils.getSimilarity(
|
||||||
match.getDetectWord(), schemaElement.getName()))
|
match.getDetectWord(), schemaElement.getName()))
|
||||||
.build();
|
.build();
|
||||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||||
|
|||||||
@@ -7,14 +7,46 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class MapFilter {
|
public class MapFilter {
|
||||||
|
|
||||||
|
public static void filter(ChatQueryContext chatQueryContext) {
|
||||||
|
filterByDataSetId(chatQueryContext);
|
||||||
|
filterByDetectWordLenLessThanOne(chatQueryContext);
|
||||||
|
switch (chatQueryContext.getQueryDataType()) {
|
||||||
|
case TAG:
|
||||||
|
filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
filterByQueryDataType(
|
||||||
|
chatQueryContext,
|
||||||
|
element -> !SchemaElementType.METRIC.equals(element.getType()));
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
filterByQueryDataType(
|
||||||
|
chatQueryContext,
|
||||||
|
element -> {
|
||||||
|
boolean isDimensionOrValue =
|
||||||
|
SchemaElementType.DIMENSION.equals(element.getType())
|
||||||
|
|| SchemaElementType.VALUE.equals(element.getType());
|
||||||
|
return !isDimensionOrValue;
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
case ALL:
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
filterByRules(chatQueryContext);
|
||||||
|
}
|
||||||
|
|
||||||
public static void filterByDataSetId(ChatQueryContext chatQueryContext) {
|
public static void filterByDataSetId(ChatQueryContext chatQueryContext) {
|
||||||
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
||||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||||
@@ -44,25 +76,93 @@ public class MapFilter {
|
|||||||
|
|
||||||
public static void filterByQueryDataType(
|
public static void filterByQueryDataType(
|
||||||
ChatQueryContext chatQueryContext, Predicate<SchemaElement> needRemovePredicate) {
|
ChatQueryContext chatQueryContext, Predicate<SchemaElement> needRemovePredicate) {
|
||||||
chatQueryContext
|
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||||
.getMapInfo()
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
.getDataSetElementMatches()
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||||
.values()
|
List<SchemaElementMatch> schemaElementMatches = entry.getValue();
|
||||||
.forEach(
|
schemaElementMatches.removeIf(
|
||||||
schemaElementMatches -> {
|
schemaElementMatch -> {
|
||||||
schemaElementMatches.removeIf(
|
SchemaElement element = schemaElementMatch.getElement();
|
||||||
schemaElementMatch -> {
|
SchemaElementType type = element.getType();
|
||||||
SchemaElement element = schemaElementMatch.getElement();
|
|
||||||
SchemaElementType type = element.getType();
|
|
||||||
|
|
||||||
boolean isEntityOrDatasetOrId =
|
boolean isEntityOrDatasetOrId =
|
||||||
SchemaElementType.ENTITY.equals(type)
|
SchemaElementType.ENTITY.equals(type)
|
||||||
|| SchemaElementType.DATASET.equals(type)
|
|| SchemaElementType.DATASET.equals(type)
|
||||||
|| SchemaElementType.ID.equals(type);
|
|| SchemaElementType.ID.equals(type);
|
||||||
|
|
||||||
return !isEntityOrDatasetOrId
|
return !isEntityOrDatasetOrId && needRemovePredicate.test(element);
|
||||||
&& needRemovePredicate.test(element);
|
});
|
||||||
});
|
}
|
||||||
});
|
}
|
||||||
|
|
||||||
|
public static void filterByRules(ChatQueryContext chatQueryContext) {
|
||||||
|
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||||
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||||
|
List<SchemaElementMatch> elementMatches = entry.getValue();
|
||||||
|
filterByExactMatch(elementMatches);
|
||||||
|
filterInExactMatch(elementMatches);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<SchemaElementMatch> filterByExactMatch(List<SchemaElementMatch> matches) {
|
||||||
|
// Group by detectWord
|
||||||
|
Map<String, List<SchemaElementMatch>> groupedByDetectWord =
|
||||||
|
matches.stream().collect(Collectors.groupingBy(SchemaElementMatch::getDetectWord));
|
||||||
|
|
||||||
|
List<SchemaElementMatch> result = new ArrayList<>();
|
||||||
|
|
||||||
|
for (Map.Entry<String, List<SchemaElementMatch>> entry : groupedByDetectWord.entrySet()) {
|
||||||
|
List<SchemaElementMatch> group = entry.getValue();
|
||||||
|
|
||||||
|
// Filter out objects with similarity=1.0
|
||||||
|
List<SchemaElementMatch> fullMatches =
|
||||||
|
group.stream()
|
||||||
|
.filter(SchemaElementMatch::isFullMatched)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (!fullMatches.isEmpty()) {
|
||||||
|
// If there are objects with similarity=1.0, choose the one with the longest
|
||||||
|
// detectWord and smallest offset
|
||||||
|
SchemaElementMatch bestMatch =
|
||||||
|
fullMatches.stream()
|
||||||
|
.max(
|
||||||
|
Comparator.comparing(
|
||||||
|
(SchemaElementMatch match) ->
|
||||||
|
match.getDetectWord().length()))
|
||||||
|
.orElse(null);
|
||||||
|
if (bestMatch != null) {
|
||||||
|
result.add(bestMatch);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If there are no objects with similarity=1.0, keep all objects with similarity<1.0
|
||||||
|
result.addAll(group);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<SchemaElementMatch> filterInExactMatch(List<SchemaElementMatch> matches) {
|
||||||
|
Map<String, List<SchemaElementMatch>> fullMatches =
|
||||||
|
matches.stream()
|
||||||
|
.filter(schemaElementMatch -> schemaElementMatch.isFullMatched())
|
||||||
|
.collect(Collectors.groupingBy(SchemaElementMatch::getDetectWord));
|
||||||
|
Set<String> keys = new HashSet<>(fullMatches.keySet());
|
||||||
|
for (String key1 : keys) {
|
||||||
|
for (String key2 : keys) {
|
||||||
|
if (!key1.equals(key2) && key1.contains(key2)) {
|
||||||
|
fullMatches.remove(key2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
List<SchemaElementMatch> notFullMatches =
|
||||||
|
matches.stream()
|
||||||
|
.filter(schemaElementMatch -> !schemaElementMatch.isFullMatched())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
List<SchemaElementMatch> mergedMatches = new ArrayList<>();
|
||||||
|
fullMatches.values().forEach(mergedMatches::addAll);
|
||||||
|
mergedMatches.addAll(notFullMatches);
|
||||||
|
return mergedMatches;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -67,19 +66,4 @@ public class MapperHelper {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* * get similarity
|
|
||||||
*
|
|
||||||
* @param detectSegment
|
|
||||||
* @param matchName
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public double getSimilarity(String detectSegment, String matchName) {
|
|
||||||
String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
|
|
||||||
String matchNameLower = matchName == null ? null : matchName.toLowerCase();
|
|
||||||
return 1
|
|
||||||
- (double) EditDistance.compute(detectSegmentLower, matchNameLower)
|
|
||||||
/ Math.max(matchName.length(), detectSegment.length());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.utils;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class EditDistanceUtils {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* * get similarity
|
||||||
|
*
|
||||||
|
* @param detectSegment
|
||||||
|
* @param matchName
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static double getSimilarity(String detectSegment, String matchName) {
|
||||||
|
String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
|
||||||
|
String matchNameLower = matchName == null ? null : matchName.toLowerCase();
|
||||||
|
return 1
|
||||||
|
- (double) EditDistance.compute(detectSegmentLower, matchNameLower)
|
||||||
|
/ Math.max(matchName.length(), detectSegment.length());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user