mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][headless-chat] Optimize HeuristicDataSetResolver to prioritize max similarity of dataset and metric. #1690
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class DataSetMatchResult {
|
||||
private Integer count = 0;
|
||||
private double maxSimilarity;
|
||||
private double maxMetricSimilarity;
|
||||
private double maxDatesetSimilarity;
|
||||
private double totalSimilarity;
|
||||
}
|
||||
|
||||
@@ -4,157 +4,92 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* HeuristicDataSetResolver select ONE most suitable data set out of matched data sets. The
|
||||
* selection is based on similarity comparison rule and the priority is like: 1.
|
||||
* maxSimilarity(matched dataset) 2. maxSimilarity(all matched metrics) 3. totalSimilarity(all
|
||||
* matched elements)
|
||||
*/
|
||||
@Slf4j
|
||||
public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
|
||||
protected static Long selectDataSetBySchemaElementMatchScore(
|
||||
Map<Long, SemanticQuery> dataSetQueryModes, SchemaMapInfo schemaMap) {
|
||||
// dataSet count priority
|
||||
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
|
||||
if (Objects.nonNull(dataSetIdByDataSetCount)) {
|
||||
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
|
||||
return dataSetIdByDataSetCount;
|
||||
}
|
||||
|
||||
Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
|
||||
if (dataSetTypeMap.size() == 1) {
|
||||
Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey();
|
||||
if (dataSetQueryModes.containsKey(dataSetSelect)) {
|
||||
log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
|
||||
return dataSetSelect;
|
||||
}
|
||||
} else {
|
||||
Entry<Long, DataSetMatchResult> maxDataSet =
|
||||
dataSetTypeMap.entrySet().stream()
|
||||
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
|
||||
.sorted(
|
||||
(o1, o2) -> {
|
||||
int difference =
|
||||
o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int)
|
||||
((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue()
|
||||
.getMaxSimilarity())
|
||||
* 100);
|
||||
}
|
||||
return difference;
|
||||
})
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
if (maxDataSet != null) {
|
||||
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
|
||||
return maxDataSet.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||
schemaMap.getDataSetElementMatches();
|
||||
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets
|
||||
// 0.5 point
|
||||
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
|
||||
if (Objects.nonNull(dataSetElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch :
|
||||
dataSetElementMatches.entrySet()) {
|
||||
Long dataSetId = dataSetElementMatch.getKey();
|
||||
List<Double> dataSetMatchesScore =
|
||||
dataSetElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(
|
||||
elementMatch ->
|
||||
SchemaElementType.DATASET.equals(
|
||||
elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
|
||||
// get sum of dataSet match score
|
||||
double score =
|
||||
dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
dataSetIdToDataSetScore.put(dataSetId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Double> maxDataSetScore =
|
||||
dataSetIdToDataSetScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue))
|
||||
.orElse(null);
|
||||
log.info(
|
||||
"maxDataSetCount:{},dataSetIdToDataSetCount:{}",
|
||||
maxDataSetScore,
|
||||
dataSetIdToDataSetScore);
|
||||
if (Objects.nonNull(maxDataSetScore)) {
|
||||
return maxDataSetScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
|
||||
for (Entry<Long, List<SchemaElementMatch>> entry :
|
||||
schemaMap.getDataSetElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!dataSetCount.containsKey(entry.getKey())) {
|
||||
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
|
||||
}
|
||||
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(
|
||||
schemaElementMatch ->
|
||||
schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax =
|
||||
schemaElementMatches.stream()
|
||||
.sorted(
|
||||
(o1, o2) ->
|
||||
((int)
|
||||
((o2.getSimilarity() - o1.getSimilarity())
|
||||
* 100)))
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
dataSetMatchResult.setCount(schemaElementTypes.size());
|
||||
}
|
||||
}
|
||||
return dataSetCount;
|
||||
}
|
||||
|
||||
public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
|
||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
||||
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
|
||||
matchedDataSets.retainAll(agentDataSetIds);
|
||||
}
|
||||
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
|
||||
for (Long dataSetIds : matchedDataSets) {
|
||||
dataSetQueryModes.put(dataSetIds, null);
|
||||
if (matchedDataSets.size() == 1) {
|
||||
return matchedDataSets.stream().findFirst().get();
|
||||
}
|
||||
if (dataSetQueryModes.size() == 1) {
|
||||
return dataSetQueryModes.keySet().stream().findFirst().get();
|
||||
return selectDataSetByMatchSimilarity(mapInfo);
|
||||
}
|
||||
|
||||
protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
|
||||
Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
|
||||
Entry<Long, DataSetMatchResult> selectedDataset =
|
||||
dataSetMatchRet.entrySet().stream()
|
||||
.sorted(
|
||||
(o1, o2) -> {
|
||||
double difference =
|
||||
o1.getValue().getMaxDatesetSimilarity()
|
||||
- o2.getValue().getMaxDatesetSimilarity();
|
||||
if (difference == 0) {
|
||||
difference =
|
||||
o1.getValue().getMaxMetricSimilarity()
|
||||
- o2.getValue().getMaxMetricSimilarity();
|
||||
if (difference == 0) {
|
||||
difference =
|
||||
o1.getValue().getTotalSimilarity()
|
||||
- o2.getValue().getTotalSimilarity();
|
||||
}
|
||||
}
|
||||
return difference >= 0 ? -1 : 1;
|
||||
})
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
if (selectedDataset != null) {
|
||||
log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey());
|
||||
return selectedDataset.getKey();
|
||||
}
|
||||
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
|
||||
Map<Long, DataSetMatchResult> dateSetMatchRet = new HashMap<>();
|
||||
for (Entry<Long, List<SchemaElementMatch>> entry :
|
||||
schemaMap.getDataSetElementMatches().entrySet()) {
|
||||
double maxMetricSimilarity = 0;
|
||||
double maxDatasetSimilarity = 0;
|
||||
double totalSimilarity = 0;
|
||||
for (SchemaElementMatch match : entry.getValue()) {
|
||||
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
|
||||
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
|
||||
}
|
||||
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
|
||||
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
|
||||
}
|
||||
totalSimilarity += match.getSimilarity();
|
||||
}
|
||||
dateSetMatchRet.put(
|
||||
entry.getKey(),
|
||||
DataSetMatchResult.builder()
|
||||
.maxMetricSimilarity(maxMetricSimilarity)
|
||||
.maxDatesetSimilarity(maxDatasetSimilarity)
|
||||
.totalSimilarity(totalSimilarity)
|
||||
.build());
|
||||
}
|
||||
|
||||
return dateSetMatchRet;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.HeuristicDataSetResolver;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class HeuristicDataSetResolverTest {
|
||||
|
||||
private DataSetResolver resolver = new HeuristicDataSetResolver();
|
||||
|
||||
@Test
|
||||
public void testMaxDatasetSimilarity() {
|
||||
Set<Long> dataSets = Sets.newHashSet(1L, 2L);
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
Map<Long, List<SchemaElementMatch>> dataSet2Matches =
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||
List<SchemaElementMatch> matches = Lists.newArrayList();
|
||||
matches.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("超音数")
|
||||
.type(SchemaElementType.DATASET)
|
||||
.build())
|
||||
.similarity(1)
|
||||
.build());
|
||||
matches.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("访问次数")
|
||||
.type(SchemaElementType.METRIC)
|
||||
.build())
|
||||
.similarity(0.5)
|
||||
.build());
|
||||
dataSet2Matches.put(1L, matches);
|
||||
|
||||
List<SchemaElementMatch> matches2 = Lists.newArrayList();
|
||||
matches2.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(2L)
|
||||
.name("访问用户数")
|
||||
.type(SchemaElementType.METRIC)
|
||||
.build())
|
||||
.similarity(1)
|
||||
.build());
|
||||
matches2.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(2L)
|
||||
.name("用户")
|
||||
.type(SchemaElementType.DIMENSION)
|
||||
.build())
|
||||
.similarity(1)
|
||||
.build());
|
||||
dataSet2Matches.put(2L, matches2);
|
||||
|
||||
Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
|
||||
assert resolvedDataset == 1L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaxMetricSimilarity() {
|
||||
Set<Long> dataSets = Sets.newHashSet(1L, 2L);
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
Map<Long, List<SchemaElementMatch>> dataSet2Matches =
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||
List<SchemaElementMatch> matches = Lists.newArrayList();
|
||||
matches.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("访问次数")
|
||||
.type(SchemaElementType.METRIC)
|
||||
.build())
|
||||
.similarity(1)
|
||||
.build());
|
||||
dataSet2Matches.put(1L, matches);
|
||||
|
||||
List<SchemaElementMatch> matches2 = Lists.newArrayList();
|
||||
matches2.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(2L)
|
||||
.name("访问用户数")
|
||||
.type(SchemaElementType.METRIC)
|
||||
.build())
|
||||
.similarity(0.6)
|
||||
.build());
|
||||
matches2.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(2L)
|
||||
.name("用户")
|
||||
.type(SchemaElementType.DIMENSION)
|
||||
.build())
|
||||
.similarity(1)
|
||||
.build());
|
||||
dataSet2Matches.put(2L, matches2);
|
||||
|
||||
Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
|
||||
assert resolvedDataset == 1L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTotalSimilarity() {
|
||||
Set<Long> dataSets = Sets.newHashSet(1L, 2L);
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
Map<Long, List<SchemaElementMatch>> dataSet2Matches =
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||
List<SchemaElementMatch> matches = Lists.newArrayList();
|
||||
matches.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("访问次数")
|
||||
.type(SchemaElementType.METRIC)
|
||||
.build())
|
||||
.similarity(0.8)
|
||||
.build());
|
||||
matches.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("部门")
|
||||
.type(SchemaElementType.METRIC)
|
||||
.build())
|
||||
.similarity(0.7)
|
||||
.build());
|
||||
dataSet2Matches.put(1L, matches);
|
||||
|
||||
List<SchemaElementMatch> matches2 = Lists.newArrayList();
|
||||
matches2.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(2L)
|
||||
.name("访问用户数")
|
||||
.type(SchemaElementType.METRIC)
|
||||
.build())
|
||||
.similarity(0.8)
|
||||
.build());
|
||||
matches2.add(
|
||||
SchemaElementMatch.builder()
|
||||
.element(
|
||||
SchemaElement.builder()
|
||||
.dataSetId(2L)
|
||||
.name("用户")
|
||||
.type(SchemaElementType.DIMENSION)
|
||||
.build())
|
||||
.similarity(1)
|
||||
.build());
|
||||
dataSet2Matches.put(2L, matches2);
|
||||
|
||||
Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
|
||||
assert resolvedDataset == 2L;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user