[improvement][headless-chat]Optimize HeuristicDataSetResolver to prioritize max similarity of dataset and metric.#1690

This commit is contained in:
jerryjzhang
2024-09-20 14:25:51 +08:00
parent 5ba401addf
commit 2c7758d0ca
3 changed files with 250 additions and 134 deletions

View File

@@ -1,9 +1,12 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import lombok.Builder;
import lombok.Data;
@Data
@Builder
public class DataSetMatchResult {
private Integer count = 0;
private double maxSimilarity;
private double maxMetricSimilarity;
private double maxDatesetSimilarity;
private double totalSimilarity;
}

View File

@@ -4,157 +4,92 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * HeuristicDataSetResolver selects the ONE most suitable data set out of the matched data sets.
 * The selection compares similarity scores in the following priority order:
 * 1. maxSimilarity(matched dataset)
 * 2. maxSimilarity(all matched metrics)
 * 3. totalSimilarity(all matched elements)
 */
@Slf4j
public class HeuristicDataSetResolver implements DataSetResolver {
/**
 * Selects a data set id from the schema element matches.
 *
 * <p>Precedence:
 *
 * <ol>
 *   <li>the id returned by {@code getDataSetIdByMatchDataSetScore}, when it is non-null (this
 *       step does not consult {@code dataSetQueryModes});
 *   <li>otherwise, among the data sets that are keys of {@code dataSetQueryModes}, the one with
 *       the highest {@code count}, ties broken by the highest {@code maxSimilarity}.
 * </ol>
 *
 * @param dataSetQueryModes candidate data sets; only the key set is consulted
 * @param schemaMap schema mapping holding the element matches per data set
 * @return the selected data set id, or {@code null} when no candidate qualifies
 */
protected static Long selectDataSetBySchemaElementMatchScore(
        Map<Long, SemanticQuery> dataSetQueryModes, SchemaMapInfo schemaMap) {
    // dataSet count priority
    Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
    if (Objects.nonNull(dataSetIdByDataSetCount)) {
        log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
        return dataSetIdByDataSetCount;
    }

    Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
    if (dataSetTypeMap.size() == 1) {
        Long dataSetSelect = dataSetTypeMap.keySet().iterator().next();
        if (dataSetQueryModes.containsKey(dataSetSelect)) {
            log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
            return dataSetSelect;
        }
        return null;
    }

    // Compare with Integer/Double comparison instead of subtracting and truncating to int:
    // "(int) ((a - b) * 100)" treats any similarity gap below 0.01 as a tie and can yield an
    // inconsistent ordering.
    Comparator<Entry<Long, DataSetMatchResult>> byCountThenSimilarity =
            Comparator.comparingInt(
                            (Entry<Long, DataSetMatchResult> e) -> e.getValue().getCount())
                    .thenComparingDouble(e -> e.getValue().getMaxSimilarity());

    // max() keeps the first of several equally ranked entries, like sorted().findFirst() did.
    Entry<Long, DataSetMatchResult> maxDataSet =
            dataSetTypeMap.entrySet().stream()
                    .filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
                    .max(byCountThenSimilarity)
                    .orElse(null);
    if (maxDataSet != null) {
        log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
        return maxDataSet.getKey();
    }
    return null;
}
/**
 * Scores every data set by its DATASET-typed element matches whose similarity is at least 1 and
 * returns the id with the highest score.
 *
 * <p>Each qualifying match contributes 1.0 point, or 0.5 point when it is inherited. Data sets
 * without any qualifying match are not scored.
 *
 * @param schemaMap schema mapping holding the element matches per data set
 * @return the id of the best-scored data set, or {@code null} when the match map is null or no
 *     data set has a qualifying match
 */
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
    Map<Long, List<SchemaElementMatch>> matchesByDataSet = schemaMap.getDataSetElementMatches();
    if (Objects.isNull(matchesByDataSet)) {
        return null;
    }

    Map<Long, Double> scoreByDataSet = new HashMap<>();
    for (Entry<Long, List<SchemaElementMatch>> candidate : matchesByDataSet.entrySet()) {
        boolean hasQualifyingMatch = false;
        double score = 0;
        for (SchemaElementMatch match : candidate.getValue()) {
            boolean qualifies =
                    match.getSimilarity() >= 1
                            && SchemaElementType.DATASET.equals(match.getElement().getType());
            if (!qualifies) {
                continue;
            }
            hasQualifyingMatch = true;
            // matched element gets 1.0 point, and inherit element gets 0.5 point
            score += match.isInherited() ? 0.5 : 1.0;
        }
        if (hasQualifyingMatch) {
            scoreByDataSet.put(candidate.getKey(), score);
        }
    }

    // Keep the first entry on equal scores; only a strictly higher score replaces it.
    Entry<Long, Double> best = null;
    for (Entry<Long, Double> scored : scoreByDataSet.entrySet()) {
        if (best == null || Double.compare(scored.getValue(), best.getValue()) > 0) {
            best = scored;
        }
    }
    log.info("maxDataSetCount:{},dataSetIdToDataSetCount:{}", best, scoreByDataSet);
    return best == null ? null : best.getKey();
}
/**
 * Builds, for every data set that has matched elements, a {@link DataSetMatchResult} holding the
 * number of distinct matched element types ({@code count}) and the highest similarity among the
 * matched elements ({@code maxSimilarity}).
 *
 * @param schemaMap schema mapping holding the element matches per data set
 * @return results keyed by data set id; data sets without matched elements are omitted
 */
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
    Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
    for (Long dataSetId : schemaMap.getDataSetElementMatches().keySet()) {
        List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(dataSetId);
        if (CollectionUtils.isEmpty(schemaElementMatches)) {
            continue;
        }
        DataSetMatchResult dataSetMatchResult =
                dataSetCount.computeIfAbsent(dataSetId, id -> new DataSetMatchResult());

        Set<SchemaElementType> schemaElementTypes =
                schemaElementMatches.stream()
                        .map(schemaElementMatch -> schemaElementMatch.getElement().getType())
                        .collect(Collectors.toSet());

        // Take the true maximum. The previous "(int) ((a - b) * 100)" sort comparator treated
        // similarity gaps below 0.01 as ties, so the first element could be a non-maximum.
        schemaElementMatches.stream()
                .mapToDouble(SchemaElementMatch::getSimilarity)
                .max()
                .ifPresent(dataSetMatchResult::setMaxSimilarity);

        dataSetMatchResult.setCount(schemaElementTypes.size());
    }
    return dataSetCount;
}
public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
matchedDataSets.retainAll(agentDataSetIds);
}
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
for (Long dataSetIds : matchedDataSets) {
dataSetQueryModes.put(dataSetIds, null);
if (matchedDataSets.size() == 1) {
return matchedDataSets.stream().findFirst().get();
}
if (dataSetQueryModes.size() == 1) {
return dataSetQueryModes.keySet().stream().findFirst().get();
return selectDataSetByMatchSimilarity(mapInfo);
}
protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
Entry<Long, DataSetMatchResult> selectedDataset =
dataSetMatchRet.entrySet().stream()
.sorted(
(o1, o2) -> {
double difference =
o1.getValue().getMaxDatesetSimilarity()
- o2.getValue().getMaxDatesetSimilarity();
if (difference == 0) {
difference =
o1.getValue().getMaxMetricSimilarity()
- o2.getValue().getMaxMetricSimilarity();
if (difference == 0) {
difference =
o1.getValue().getTotalSimilarity()
- o2.getValue().getTotalSimilarity();
}
}
return difference >= 0 ? -1 : 1;
})
.findFirst()
.orElse(null);
if (selectedDataset != null) {
log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey());
return selectedDataset.getKey();
}
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
return null;
}
/**
 * Summarises the element matches of every data set in the schema mapping.
 *
 * <p>For each data set the result records the highest similarity among DATASET-typed matches,
 * the highest similarity among METRIC-typed matches, and the sum of the similarity of all
 * matches. A maximum stays 0 when the data set has no match of that type.
 *
 * @param schemaMap schema mapping holding the element matches per data set
 * @return match summaries keyed by data set id
 */
protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
    Map<Long, DataSetMatchResult> resultByDataSet = new HashMap<>();
    for (Entry<Long, List<SchemaElementMatch>> dataSetMatches :
            schemaMap.getDataSetElementMatches().entrySet()) {
        double bestDataset = 0;
        double bestMetric = 0;
        double sum = 0;
        for (SchemaElementMatch elementMatch : dataSetMatches.getValue()) {
            SchemaElementType type = elementMatch.getElement().getType();
            double similarity = elementMatch.getSimilarity();
            if (SchemaElementType.DATASET.equals(type)) {
                bestDataset = Math.max(bestDataset, similarity);
            } else if (SchemaElementType.METRIC.equals(type)) {
                bestMetric = Math.max(bestMetric, similarity);
            }
            sum += similarity;
        }
        DataSetMatchResult summary =
                DataSetMatchResult.builder()
                        .maxMetricSimilarity(bestMetric)
                        .maxDatesetSimilarity(bestDataset)
                        .totalSimilarity(sum)
                        .build();
        resultByDataSet.put(dataSetMatches.getKey(), summary);
    }
    return resultByDataSet;
}
}

View File

@@ -0,0 +1,178 @@
package com.tencent.supersonic.headless.chat.parser;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
import com.tencent.supersonic.headless.chat.parser.llm.HeuristicDataSetResolver;
import org.junit.Assert;
import org.junit.Test;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Tests the selection priority of {@link HeuristicDataSetResolver}: dataset similarity first,
 * then metric similarity, then total similarity.
 *
 * <p>Results are checked with JUnit assertions rather than the {@code assert} keyword, which is
 * a no-op when the JVM runs without {@code -ea} and fails with a bare NullPointerException when
 * the resolver returns {@code null}.
 */
public class HeuristicDataSetResolverTest {
    private final DataSetResolver resolver = new HeuristicDataSetResolver();

    /** A full dataset-name match wins even if the other data set has better metric matches. */
    @Test
    public void testMaxDatasetSimilarity() {
        Set<Long> dataSets = Sets.newHashSet(1L, 2L);
        ChatQueryContext chatQueryContext = new ChatQueryContext();
        Map<Long, List<SchemaElementMatch>> dataSet2Matches =
                chatQueryContext.getMapInfo().getDataSetElementMatches();
        dataSet2Matches.put(
                1L,
                Lists.newArrayList(
                        match(1L, "超音数", SchemaElementType.DATASET, 1),
                        match(1L, "访问次数", SchemaElementType.METRIC, 0.5)));
        dataSet2Matches.put(
                2L,
                Lists.newArrayList(
                        match(2L, "访问用户数", SchemaElementType.METRIC, 1),
                        match(2L, "用户", SchemaElementType.DIMENSION, 1)));

        Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);

        Assert.assertEquals(Long.valueOf(1L), resolvedDataset);
    }

    /** Without dataset matches, the higher metric similarity wins over total similarity. */
    @Test
    public void testMaxMetricSimilarity() {
        Set<Long> dataSets = Sets.newHashSet(1L, 2L);
        ChatQueryContext chatQueryContext = new ChatQueryContext();
        Map<Long, List<SchemaElementMatch>> dataSet2Matches =
                chatQueryContext.getMapInfo().getDataSetElementMatches();
        dataSet2Matches.put(
                1L, Lists.newArrayList(match(1L, "访问次数", SchemaElementType.METRIC, 1)));
        dataSet2Matches.put(
                2L,
                Lists.newArrayList(
                        match(2L, "访问用户数", SchemaElementType.METRIC, 0.6),
                        match(2L, "用户", SchemaElementType.DIMENSION, 1)));

        Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);

        Assert.assertEquals(Long.valueOf(1L), resolvedDataset);
    }

    /** With equal dataset and metric maxima, the higher total similarity wins. */
    @Test
    public void testTotalSimilarity() {
        Set<Long> dataSets = Sets.newHashSet(1L, 2L);
        ChatQueryContext chatQueryContext = new ChatQueryContext();
        Map<Long, List<SchemaElementMatch>> dataSet2Matches =
                chatQueryContext.getMapInfo().getDataSetElementMatches();
        dataSet2Matches.put(
                1L,
                Lists.newArrayList(
                        match(1L, "访问次数", SchemaElementType.METRIC, 0.8),
                        match(1L, "部门", SchemaElementType.METRIC, 0.7)));
        dataSet2Matches.put(
                2L,
                Lists.newArrayList(
                        match(2L, "访问用户数", SchemaElementType.METRIC, 0.8),
                        match(2L, "用户", SchemaElementType.DIMENSION, 1)));

        Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);

        Assert.assertEquals(Long.valueOf(2L), resolvedDataset);
    }

    /** Builds an element match for the given data set, element name, type and similarity. */
    private static SchemaElementMatch match(
            long dataSetId, String name, SchemaElementType type, double similarity) {
        return SchemaElementMatch.builder()
                .element(SchemaElement.builder().dataSetId(dataSetId).name(name).type(type).build())
                .similarity(similarity)
                .build();
    }
}