mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
[improvement][headless-chat]Optimize HeuristicDataSetResolver to prioritize max similarity of dataset and metric.#1690
This commit is contained in:
@@ -1,9 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
public class DataSetMatchResult {
|
public class DataSetMatchResult {
|
||||||
private Integer count = 0;
|
private double maxMetricSimilarity;
|
||||||
private double maxSimilarity;
|
private double maxDatesetSimilarity;
|
||||||
|
private double totalSimilarity;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,157 +4,92 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Map.Entry;
|
import java.util.Map.Entry;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* HeuristicDataSetResolver select ONE most suitable data set out of matched data sets. The
|
||||||
|
* selection is based on similarity comparison rule and the priority is like: 1.
|
||||||
|
* maxSimilarity(matched dataset) 2. maxSimilarity(all matched metrics) 3. totalSimilarity(all
|
||||||
|
* matched elements)
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class HeuristicDataSetResolver implements DataSetResolver {
|
public class HeuristicDataSetResolver implements DataSetResolver {
|
||||||
|
|
||||||
protected static Long selectDataSetBySchemaElementMatchScore(
|
|
||||||
Map<Long, SemanticQuery> dataSetQueryModes, SchemaMapInfo schemaMap) {
|
|
||||||
// dataSet count priority
|
|
||||||
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
|
|
||||||
if (Objects.nonNull(dataSetIdByDataSetCount)) {
|
|
||||||
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
|
|
||||||
return dataSetIdByDataSetCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
|
|
||||||
if (dataSetTypeMap.size() == 1) {
|
|
||||||
Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey();
|
|
||||||
if (dataSetQueryModes.containsKey(dataSetSelect)) {
|
|
||||||
log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
|
|
||||||
return dataSetSelect;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Entry<Long, DataSetMatchResult> maxDataSet =
|
|
||||||
dataSetTypeMap.entrySet().stream()
|
|
||||||
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
|
|
||||||
.sorted(
|
|
||||||
(o1, o2) -> {
|
|
||||||
int difference =
|
|
||||||
o2.getValue().getCount() - o1.getValue().getCount();
|
|
||||||
if (difference == 0) {
|
|
||||||
return (int)
|
|
||||||
((o2.getValue().getMaxSimilarity()
|
|
||||||
- o1.getValue()
|
|
||||||
.getMaxSimilarity())
|
|
||||||
* 100);
|
|
||||||
}
|
|
||||||
return difference;
|
|
||||||
})
|
|
||||||
.findFirst()
|
|
||||||
.orElse(null);
|
|
||||||
if (maxDataSet != null) {
|
|
||||||
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
|
|
||||||
return maxDataSet.getKey();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
|
|
||||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
|
||||||
schemaMap.getDataSetElementMatches();
|
|
||||||
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets
|
|
||||||
// 0.5 point
|
|
||||||
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
|
|
||||||
if (Objects.nonNull(dataSetElementMatches)) {
|
|
||||||
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch :
|
|
||||||
dataSetElementMatches.entrySet()) {
|
|
||||||
Long dataSetId = dataSetElementMatch.getKey();
|
|
||||||
List<Double> dataSetMatchesScore =
|
|
||||||
dataSetElementMatch.getValue().stream()
|
|
||||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
|
||||||
.filter(
|
|
||||||
elementMatch ->
|
|
||||||
SchemaElementType.DATASET.equals(
|
|
||||||
elementMatch.getElement().getType()))
|
|
||||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
|
|
||||||
// get sum of dataSet match score
|
|
||||||
double score =
|
|
||||||
dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
|
||||||
dataSetIdToDataSetScore.put(dataSetId, score);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Entry<Long, Double> maxDataSetScore =
|
|
||||||
dataSetIdToDataSetScore.entrySet().stream()
|
|
||||||
.max(Comparator.comparingDouble(Entry::getValue))
|
|
||||||
.orElse(null);
|
|
||||||
log.info(
|
|
||||||
"maxDataSetCount:{},dataSetIdToDataSetCount:{}",
|
|
||||||
maxDataSetScore,
|
|
||||||
dataSetIdToDataSetScore);
|
|
||||||
if (Objects.nonNull(maxDataSetScore)) {
|
|
||||||
return maxDataSetScore.getKey();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
|
|
||||||
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
|
|
||||||
for (Entry<Long, List<SchemaElementMatch>> entry :
|
|
||||||
schemaMap.getDataSetElementMatches().entrySet()) {
|
|
||||||
List<SchemaElementMatch> schemaElementMatches =
|
|
||||||
schemaMap.getMatchedElements(entry.getKey());
|
|
||||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
|
||||||
if (!dataSetCount.containsKey(entry.getKey())) {
|
|
||||||
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
|
|
||||||
}
|
|
||||||
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
|
|
||||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
|
||||||
schemaElementMatches.stream()
|
|
||||||
.forEach(
|
|
||||||
schemaElementMatch ->
|
|
||||||
schemaElementTypes.add(
|
|
||||||
schemaElementMatch.getElement().getType()));
|
|
||||||
SchemaElementMatch schemaElementMatchMax =
|
|
||||||
schemaElementMatches.stream()
|
|
||||||
.sorted(
|
|
||||||
(o1, o2) ->
|
|
||||||
((int)
|
|
||||||
((o2.getSimilarity() - o1.getSimilarity())
|
|
||||||
* 100)))
|
|
||||||
.findFirst()
|
|
||||||
.orElse(null);
|
|
||||||
if (schemaElementMatchMax != null) {
|
|
||||||
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
|
||||||
}
|
|
||||||
dataSetMatchResult.setCount(schemaElementTypes.size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dataSetCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
|
public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
|
||||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||||
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
||||||
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
|
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
|
||||||
matchedDataSets.retainAll(agentDataSetIds);
|
matchedDataSets.retainAll(agentDataSetIds);
|
||||||
}
|
}
|
||||||
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
|
if (matchedDataSets.size() == 1) {
|
||||||
for (Long dataSetIds : matchedDataSets) {
|
return matchedDataSets.stream().findFirst().get();
|
||||||
dataSetQueryModes.put(dataSetIds, null);
|
|
||||||
}
|
}
|
||||||
if (dataSetQueryModes.size() == 1) {
|
return selectDataSetByMatchSimilarity(mapInfo);
|
||||||
return dataSetQueryModes.keySet().stream().findFirst().get();
|
}
|
||||||
|
|
||||||
|
protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
|
||||||
|
Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
|
||||||
|
Entry<Long, DataSetMatchResult> selectedDataset =
|
||||||
|
dataSetMatchRet.entrySet().stream()
|
||||||
|
.sorted(
|
||||||
|
(o1, o2) -> {
|
||||||
|
double difference =
|
||||||
|
o1.getValue().getMaxDatesetSimilarity()
|
||||||
|
- o2.getValue().getMaxDatesetSimilarity();
|
||||||
|
if (difference == 0) {
|
||||||
|
difference =
|
||||||
|
o1.getValue().getMaxMetricSimilarity()
|
||||||
|
- o2.getValue().getMaxMetricSimilarity();
|
||||||
|
if (difference == 0) {
|
||||||
|
difference =
|
||||||
|
o1.getValue().getTotalSimilarity()
|
||||||
|
- o2.getValue().getTotalSimilarity();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return difference >= 0 ? -1 : 1;
|
||||||
|
})
|
||||||
|
.findFirst()
|
||||||
|
.orElse(null);
|
||||||
|
if (selectedDataset != null) {
|
||||||
|
log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey());
|
||||||
|
return selectedDataset.getKey();
|
||||||
}
|
}
|
||||||
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
|
||||||
|
Map<Long, DataSetMatchResult> dateSetMatchRet = new HashMap<>();
|
||||||
|
for (Entry<Long, List<SchemaElementMatch>> entry :
|
||||||
|
schemaMap.getDataSetElementMatches().entrySet()) {
|
||||||
|
double maxMetricSimilarity = 0;
|
||||||
|
double maxDatasetSimilarity = 0;
|
||||||
|
double totalSimilarity = 0;
|
||||||
|
for (SchemaElementMatch match : entry.getValue()) {
|
||||||
|
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
|
||||||
|
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
|
||||||
|
}
|
||||||
|
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
|
||||||
|
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
|
||||||
|
}
|
||||||
|
totalSimilarity += match.getSimilarity();
|
||||||
|
}
|
||||||
|
dateSetMatchRet.put(
|
||||||
|
entry.getKey(),
|
||||||
|
DataSetMatchResult.builder()
|
||||||
|
.maxMetricSimilarity(maxMetricSimilarity)
|
||||||
|
.maxDatesetSimilarity(maxDatasetSimilarity)
|
||||||
|
.totalSimilarity(totalSimilarity)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
|
||||||
|
return dateSetMatchRet;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,178 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.parser;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.google.common.collect.Sets;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
|
||||||
|
import com.tencent.supersonic.headless.chat.parser.llm.HeuristicDataSetResolver;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
public class HeuristicDataSetResolverTest {
|
||||||
|
|
||||||
|
private DataSetResolver resolver = new HeuristicDataSetResolver();
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMaxDatasetSimilarity() {
|
||||||
|
Set<Long> dataSets = Sets.newHashSet(1L, 2L);
|
||||||
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
|
Map<Long, List<SchemaElementMatch>> dataSet2Matches =
|
||||||
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
|
List<SchemaElementMatch> matches = Lists.newArrayList();
|
||||||
|
matches.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(1L)
|
||||||
|
.name("超音数")
|
||||||
|
.type(SchemaElementType.DATASET)
|
||||||
|
.build())
|
||||||
|
.similarity(1)
|
||||||
|
.build());
|
||||||
|
matches.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(1L)
|
||||||
|
.name("访问次数")
|
||||||
|
.type(SchemaElementType.METRIC)
|
||||||
|
.build())
|
||||||
|
.similarity(0.5)
|
||||||
|
.build());
|
||||||
|
dataSet2Matches.put(1L, matches);
|
||||||
|
|
||||||
|
List<SchemaElementMatch> matches2 = Lists.newArrayList();
|
||||||
|
matches2.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(2L)
|
||||||
|
.name("访问用户数")
|
||||||
|
.type(SchemaElementType.METRIC)
|
||||||
|
.build())
|
||||||
|
.similarity(1)
|
||||||
|
.build());
|
||||||
|
matches2.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(2L)
|
||||||
|
.name("用户")
|
||||||
|
.type(SchemaElementType.DIMENSION)
|
||||||
|
.build())
|
||||||
|
.similarity(1)
|
||||||
|
.build());
|
||||||
|
dataSet2Matches.put(2L, matches2);
|
||||||
|
|
||||||
|
Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
|
||||||
|
assert resolvedDataset == 1L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMaxMetricSimilarity() {
|
||||||
|
Set<Long> dataSets = Sets.newHashSet(1L, 2L);
|
||||||
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
|
Map<Long, List<SchemaElementMatch>> dataSet2Matches =
|
||||||
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
|
List<SchemaElementMatch> matches = Lists.newArrayList();
|
||||||
|
matches.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(1L)
|
||||||
|
.name("访问次数")
|
||||||
|
.type(SchemaElementType.METRIC)
|
||||||
|
.build())
|
||||||
|
.similarity(1)
|
||||||
|
.build());
|
||||||
|
dataSet2Matches.put(1L, matches);
|
||||||
|
|
||||||
|
List<SchemaElementMatch> matches2 = Lists.newArrayList();
|
||||||
|
matches2.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(2L)
|
||||||
|
.name("访问用户数")
|
||||||
|
.type(SchemaElementType.METRIC)
|
||||||
|
.build())
|
||||||
|
.similarity(0.6)
|
||||||
|
.build());
|
||||||
|
matches2.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(2L)
|
||||||
|
.name("用户")
|
||||||
|
.type(SchemaElementType.DIMENSION)
|
||||||
|
.build())
|
||||||
|
.similarity(1)
|
||||||
|
.build());
|
||||||
|
dataSet2Matches.put(2L, matches2);
|
||||||
|
|
||||||
|
Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
|
||||||
|
assert resolvedDataset == 1L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTotalSimilarity() {
|
||||||
|
Set<Long> dataSets = Sets.newHashSet(1L, 2L);
|
||||||
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
|
Map<Long, List<SchemaElementMatch>> dataSet2Matches =
|
||||||
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
|
List<SchemaElementMatch> matches = Lists.newArrayList();
|
||||||
|
matches.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(1L)
|
||||||
|
.name("访问次数")
|
||||||
|
.type(SchemaElementType.METRIC)
|
||||||
|
.build())
|
||||||
|
.similarity(0.8)
|
||||||
|
.build());
|
||||||
|
matches.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(1L)
|
||||||
|
.name("部门")
|
||||||
|
.type(SchemaElementType.METRIC)
|
||||||
|
.build())
|
||||||
|
.similarity(0.7)
|
||||||
|
.build());
|
||||||
|
dataSet2Matches.put(1L, matches);
|
||||||
|
|
||||||
|
List<SchemaElementMatch> matches2 = Lists.newArrayList();
|
||||||
|
matches2.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(2L)
|
||||||
|
.name("访问用户数")
|
||||||
|
.type(SchemaElementType.METRIC)
|
||||||
|
.build())
|
||||||
|
.similarity(0.8)
|
||||||
|
.build());
|
||||||
|
matches2.add(
|
||||||
|
SchemaElementMatch.builder()
|
||||||
|
.element(
|
||||||
|
SchemaElement.builder()
|
||||||
|
.dataSetId(2L)
|
||||||
|
.name("用户")
|
||||||
|
.type(SchemaElementType.DIMENSION)
|
||||||
|
.build())
|
||||||
|
.similarity(1)
|
||||||
|
.build());
|
||||||
|
dataSet2Matches.put(2L, matches2);
|
||||||
|
|
||||||
|
Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
|
||||||
|
assert resolvedDataset == 2L;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user