diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java index bb612b6f4..7d5eb265f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java @@ -1,9 +1,12 @@ package com.tencent.supersonic.headless.chat.parser.llm; +import lombok.Builder; import lombok.Data; @Data +@Builder public class DataSetMatchResult { - private Integer count = 0; - private double maxSimilarity; + private double maxMetricSimilarity; + private double maxDatesetSimilarity; + private double totalSimilarity; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java index 12d518d95..b745b9925 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java @@ -4,157 +4,92 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; -import com.tencent.supersonic.headless.chat.query.SemanticQuery; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; -import java.util.ArrayList; -import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; -import java.util.Objects; import java.util.Set; -import java.util.stream.Collectors; +/** + * HeuristicDataSetResolver select ONE most suitable data set out of matched data sets. The + * selection is based on similarity comparison rule and the priority is like: 1. + * maxSimilarity(matched dataset) 2. maxSimilarity(all matched metrics) 3. totalSimilarity(all + * matched elements) + */ @Slf4j public class HeuristicDataSetResolver implements DataSetResolver { - protected static Long selectDataSetBySchemaElementMatchScore( - Map dataSetQueryModes, SchemaMapInfo schemaMap) { - // dataSet count priority - Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap); - if (Objects.nonNull(dataSetIdByDataSetCount)) { - log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount); - return dataSetIdByDataSetCount; - } - - Map dataSetTypeMap = getDataSetTypeMap(schemaMap); - if (dataSetTypeMap.size() == 1) { - Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey(); - if (dataSetQueryModes.containsKey(dataSetSelect)) { - log.info("selectDataSet with only one DataSet [{}]", dataSetSelect); - return dataSetSelect; - } - } else { - Entry maxDataSet = - dataSetTypeMap.entrySet().stream() - .filter(entry -> dataSetQueryModes.containsKey(entry.getKey())) - .sorted( - (o1, o2) -> { - int difference = - o2.getValue().getCount() - o1.getValue().getCount(); - if (difference == 0) { - return (int) - ((o2.getValue().getMaxSimilarity() - - o1.getValue() - .getMaxSimilarity()) - * 100); - } - return difference; - }) - .findFirst() - .orElse(null); - if (maxDataSet != null) { - log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey()); - return maxDataSet.getKey(); - } - } - return null; - } - - private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) { - Map> dataSetElementMatches = - schemaMap.getDataSetElementMatches(); - // calculate dataSet match score, matched element gets 1.0 point, and inherit element gets - // 0.5 point - Map dataSetIdToDataSetScore = new HashMap<>(); - if (Objects.nonNull(dataSetElementMatches)) { - for (Entry> dataSetElementMatch : - dataSetElementMatches.entrySet()) { - Long dataSetId = dataSetElementMatch.getKey(); - List dataSetMatchesScore = - dataSetElementMatch.getValue().stream() - .filter(elementMatch -> elementMatch.getSimilarity() >= 1) - .filter( - elementMatch -> - SchemaElementType.DATASET.equals( - elementMatch.getElement().getType())) - .map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0) - .collect(Collectors.toList()); - - if (!CollectionUtils.isEmpty(dataSetMatchesScore)) { - // get sum of dataSet match score - double score = - dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum(); - dataSetIdToDataSetScore.put(dataSetId, score); - } - } - Entry maxDataSetScore = - dataSetIdToDataSetScore.entrySet().stream() - .max(Comparator.comparingDouble(Entry::getValue)) - .orElse(null); - log.info( - "maxDataSetCount:{},dataSetIdToDataSetCount:{}", - maxDataSetScore, - dataSetIdToDataSetScore); - if (Objects.nonNull(maxDataSetScore)) { - return maxDataSetScore.getKey(); - } - } - return null; - } - - public static Map getDataSetTypeMap(SchemaMapInfo schemaMap) { - Map dataSetCount = new HashMap<>(); - for (Entry> entry : - schemaMap.getDataSetElementMatches().entrySet()) { - List schemaElementMatches = - schemaMap.getMatchedElements(entry.getKey()); - if (schemaElementMatches != null && schemaElementMatches.size() > 0) { - if (!dataSetCount.containsKey(entry.getKey())) { - dataSetCount.put(entry.getKey(), new DataSetMatchResult()); - } - DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey()); - Set schemaElementTypes = new HashSet<>(); - schemaElementMatches.stream() - .forEach( - schemaElementMatch -> - schemaElementTypes.add( - schemaElementMatch.getElement().getType())); - SchemaElementMatch schemaElementMatchMax = - schemaElementMatches.stream() - .sorted( - (o1, o2) -> - ((int) - ((o2.getSimilarity() - o1.getSimilarity()) - * 100))) - .findFirst() - .orElse(null); - if (schemaElementMatchMax != null) { - dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity()); - } - dataSetMatchResult.setCount(schemaElementTypes.size()); - } - } - return dataSetCount; - } - public Long resolve(ChatQueryContext chatQueryContext, Set agentDataSetIds) { SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); Set matchedDataSets = mapInfo.getMatchedDataSetInfos(); if (CollectionUtils.isNotEmpty(agentDataSetIds)) { matchedDataSets.retainAll(agentDataSetIds); } - Map dataSetQueryModes = new HashMap<>(); - for (Long dataSetIds : matchedDataSets) { - dataSetQueryModes.put(dataSetIds, null); + if (matchedDataSets.size() == 1) { + return matchedDataSets.stream().findFirst().get(); } - if (dataSetQueryModes.size() == 1) { - return dataSetQueryModes.keySet().stream().findFirst().get(); + return selectDataSetByMatchSimilarity(mapInfo); + } + + protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) { + Map dataSetMatchRet = getDataSetMatchResult(schemaMap); + Entry selectedDataset = + dataSetMatchRet.entrySet().stream() + .sorted( + (o1, o2) -> { + double difference = + o1.getValue().getMaxDatesetSimilarity() + - o2.getValue().getMaxDatesetSimilarity(); + if (difference == 0) { + difference = + o1.getValue().getMaxMetricSimilarity() + - o2.getValue().getMaxMetricSimilarity(); + if (difference == 0) { + difference = + o1.getValue().getTotalSimilarity() + - o2.getValue().getTotalSimilarity(); + } + } + return difference >= 0 ? -1 : 1; + }) + .findFirst() + .orElse(null); + if (selectedDataset != null) { + log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey()); + return selectedDataset.getKey(); } - return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo); + + return null; + } + + protected Map getDataSetMatchResult(SchemaMapInfo schemaMap) { + Map dateSetMatchRet = new HashMap<>(); + for (Entry> entry : + schemaMap.getDataSetElementMatches().entrySet()) { + double maxMetricSimilarity = 0; + double maxDatasetSimilarity = 0; + double totalSimilarity = 0; + for (SchemaElementMatch match : entry.getValue()) { + if (SchemaElementType.DATASET.equals(match.getElement().getType())) { + maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity()); + } + if (SchemaElementType.METRIC.equals(match.getElement().getType())) { + maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity()); + } + totalSimilarity += match.getSimilarity(); + } + dateSetMatchRet.put( + entry.getKey(), + DataSetMatchResult.builder() + .maxMetricSimilarity(maxMetricSimilarity) + .maxDatesetSimilarity(maxDatasetSimilarity) + .totalSimilarity(totalSimilarity) + .build()); + } + + return dateSetMatchRet; } } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/parser/HeuristicDataSetResolverTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/parser/HeuristicDataSetResolverTest.java new file mode 100644 index 000000000..854fa4471 --- /dev/null +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/parser/HeuristicDataSetResolverTest.java @@ -0,0 +1,178 @@ +package com.tencent.supersonic.headless.chat.parser; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.headless.api.pojo.SchemaElementType; +import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver; +import com.tencent.supersonic.headless.chat.parser.llm.HeuristicDataSetResolver; +import org.junit.Test; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class HeuristicDataSetResolverTest { + + private DataSetResolver resolver = new HeuristicDataSetResolver(); + + @Test + public void testMaxDatasetSimilarity() { + Set dataSets = Sets.newHashSet(1L, 2L); + ChatQueryContext chatQueryContext = new ChatQueryContext(); + Map> dataSet2Matches = + chatQueryContext.getMapInfo().getDataSetElementMatches(); + List matches = Lists.newArrayList(); + matches.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(1L) + .name("超音数") + .type(SchemaElementType.DATASET) + .build()) + .similarity(1) + .build()); + matches.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(1L) + .name("访问次数") + .type(SchemaElementType.METRIC) + .build()) + .similarity(0.5) + .build()); + dataSet2Matches.put(1L, matches); + + List matches2 = Lists.newArrayList(); + matches2.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(2L) + .name("访问用户数") + .type(SchemaElementType.METRIC) + .build()) + .similarity(1) + .build()); + matches2.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(2L) + .name("用户") + .type(SchemaElementType.DIMENSION) + .build()) + .similarity(1) + .build()); + dataSet2Matches.put(2L, matches2); + + Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets); + assert resolvedDataset == 1L; + } + + @Test + public void testMaxMetricSimilarity() { + Set dataSets = Sets.newHashSet(1L, 2L); + ChatQueryContext chatQueryContext = new ChatQueryContext(); + Map> dataSet2Matches = + chatQueryContext.getMapInfo().getDataSetElementMatches(); + List matches = Lists.newArrayList(); + matches.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(1L) + .name("访问次数") + .type(SchemaElementType.METRIC) + .build()) + .similarity(1) + .build()); + dataSet2Matches.put(1L, matches); + + List matches2 = Lists.newArrayList(); + matches2.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(2L) + .name("访问用户数") + .type(SchemaElementType.METRIC) + .build()) + .similarity(0.6) + .build()); + matches2.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(2L) + .name("用户") + .type(SchemaElementType.DIMENSION) + .build()) + .similarity(1) + .build()); + dataSet2Matches.put(2L, matches2); + + Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets); + assert resolvedDataset == 1L; + } + + @Test + public void testTotalSimilarity() { + Set dataSets = Sets.newHashSet(1L, 2L); + ChatQueryContext chatQueryContext = new ChatQueryContext(); + Map> dataSet2Matches = + chatQueryContext.getMapInfo().getDataSetElementMatches(); + List matches = Lists.newArrayList(); + matches.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(1L) + .name("访问次数") + .type(SchemaElementType.METRIC) + .build()) + .similarity(0.8) + .build()); + matches.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(1L) + .name("部门") + .type(SchemaElementType.METRIC) + .build()) + .similarity(0.7) + .build()); + dataSet2Matches.put(1L, matches); + + List matches2 = Lists.newArrayList(); + matches2.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(2L) + .name("访问用户数") + .type(SchemaElementType.METRIC) + .build()) + .similarity(0.8) + .build()); + matches2.add( + SchemaElementMatch.builder() + .element( + SchemaElement.builder() + .dataSetId(2L) + .name("用户") + .type(SchemaElementType.DIMENSION) + .build()) + .similarity(1) + .build()); + dataSet2Matches.put(2L, matches2); + + Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets); + assert resolvedDataset == 2L; + } +}