Commit by QJ_wonder on 2025-05-26 11:03:37 +08:00 (committed by GitHub)
4 changed files with 77 additions and 81 deletions

pom.xml

@@ -122,6 +122,11 @@
             <version>${mockito-inline.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>com.huaban</groupId>
+            <artifactId>jieba-analysis</artifactId>
+            <version>${jieba.version}</version>
+        </dependency>
     </dependencies>
 </project>

HeuristicDataSetResolver.java

@@ -1,98 +1,31 @@
 package com.tencent.supersonic.headless.chat.parser.llm;
-import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
-import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
-import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
-import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
+import com.tencent.supersonic.headless.api.pojo.*;
 import com.tencent.supersonic.headless.chat.ChatQueryContext;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections.CollectionUtils;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 import java.util.Map.Entry;
-import java.util.Objects;
-import java.util.Set;
+import static com.tencent.supersonic.headless.chat.parser.llm.TextSimilarityCalculation.getDataSetSimilarity;
 /**
- * HeuristicDataSetResolver select ONE most suitable data set out of matched data sets. The
- * selection is based on similarity comparison rule and the priority is like: 1.
- * maxSimilarity(matched dataset) 2. maxSimilarity(all matched metrics) 3. totalSimilarity(all
- * matched elements)
+ * HeuristicDataSetResolver selects the ONE most suitable data set out of the candidate data
+ * sets. The selection is based on the cosine similarity between the question text and each
+ * dataset name.
  */
 @Slf4j
 public class HeuristicDataSetResolver implements DataSetResolver {
     public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
-        SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
-        Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
-        if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
-            matchedDataSets.retainAll(agentDataSetIds);
+        String queryText = chatQueryContext.getRequest().getQueryText();
+        List<SchemaElement> dataSets = chatQueryContext.getSemanticSchema().getDataSets();
+        if (dataSets.size() == 1) {
+            return dataSets.get(0).getDataSetId();
         }
-        if (matchedDataSets.size() == 1) {
-            return matchedDataSets.stream().findFirst().get();
+        Map<Long, Double> dataSetSimilarity = new LinkedHashMap<>();
+        for (SchemaElement dataSet : dataSets) {
+            dataSetSimilarity.put(dataSet.getDataSetId(),
+                    getDataSetSimilarity(queryText, dataSet.getDataSetName()));
         }
-        return selectDataSetByMatchSimilarity(mapInfo);
-    }
-    protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
-        Map<Long, SemanticParseInfo.DataSetMatchResult> dataSetMatchRet =
-                getDataSetMatchResult(schemaMap);
-        Entry<Long, SemanticParseInfo.DataSetMatchResult> selectedDataset =
-                dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> {
-                    double difference = o1.getValue().getMaxDatesetSimilarity()
-                            - o2.getValue().getMaxDatesetSimilarity();
-                    if (difference == 0) {
-                        difference = o1.getValue().getMaxMetricSimilarity()
-                                - o2.getValue().getMaxMetricSimilarity();
-                        if (difference == 0) {
-                            difference = o1.getValue().getTotalSimilarity()
-                                    - o2.getValue().getTotalSimilarity();
-                        }
-                        if (difference == 0) {
-                            difference = o1.getValue().getMaxMetricUseCnt()
-                                    - o2.getValue().getMaxMetricUseCnt();
-                        }
-                    }
-                    return difference >= 0 ? -1 : 1;
-                }).findFirst().orElse(null);
-        if (selectedDataset != null) {
-            log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey());
-            return selectedDataset.getKey();
-        }
-        return null;
-    }
-    protected Map<Long, SemanticParseInfo.DataSetMatchResult> getDataSetMatchResult(
-            SchemaMapInfo schemaMap) {
-        Map<Long, SemanticParseInfo.DataSetMatchResult> dateSetMatchRet = new HashMap<>();
-        for (Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches()
-                .entrySet()) {
-            double maxMetricSimilarity = 0;
-            double maxDatasetSimilarity = 0;
-            double totalSimilarity = 0;
-            long maxMetricUseCnt = 0L;
-            for (SchemaElementMatch match : entry.getValue()) {
-                if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
-                    maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
-                }
-                if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
-                    maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
-                    if (Objects.nonNull(match.getElement().getUseCnt())) {
-                        maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
-                    }
-                }
-                totalSimilarity += match.getSimilarity();
-            }
-            dateSetMatchRet.put(entry.getKey(),
-                    SemanticParseInfo.DataSetMatchResult.builder()
-                            .maxMetricSimilarity(maxMetricSimilarity)
-                            .maxDatesetSimilarity(maxDatasetSimilarity)
-                            .totalSimilarity(totalSimilarity).build());
-        }
-        return dateSetMatchRet;
+        // Highest-scoring dataset wins; orElse(null) avoids a NoSuchElementException
+        // on an empty candidate list and preserves the old "no match" null contract
+        return dataSetSimilarity.entrySet().stream().max(Map.Entry.comparingByValue())
+                .map(Map.Entry::getKey).orElse(null);
     }
 }
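The rewritten resolve() reduces dataset selection to a single argmax over the per-dataset cosine scores. A minimal standalone sketch of just that selection step (the class name and the dataset IDs/scores below are made-up illustration values, not part of this commit):

import java.util.LinkedHashMap;
import java.util.Map;

public class SelectionSketch {
    public static void main(String[] args) {
        // Hypothetical cosine scores keyed by dataset ID
        Map<Long, Double> dataSetSimilarity = new LinkedHashMap<>();
        dataSetSimilarity.put(1L, 0.12);
        dataSetSimilarity.put(2L, 0.58);
        dataSetSimilarity.put(3L, 0.31);
        // Highest score wins; an empty map yields null instead of throwing
        Long selected = dataSetSimilarity.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(null);
        System.out.println(selected); // prints 2
    }
}

Because the scores sit in a LinkedHashMap, candidates are compared in insertion (schema) order, so when two datasets tie the earlier one effectively wins.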

TextSimilarityCalculation.java

@@ -0,0 +1,52 @@
+package com.tencent.supersonic.headless.chat.parser.llm;
+import com.huaban.analysis.jieba.JiebaSegmenter;
+import lombok.extern.slf4j.Slf4j;
+import java.util.*;
+@Slf4j
+public class TextSimilarityCalculation {
+    // Build a term-frequency vector for the given tokens over the shared vocabulary
+    private static double[] createVector(List<String> words, List<String> vocabulary) {
+        double[] vector = new double[vocabulary.size()];
+        Map<String, Integer> wordFreq = new HashMap<>();
+        for (String word : words) {
+            wordFreq.put(word, wordFreq.getOrDefault(word, 0) + 1);
+        }
+        for (int i = 0; i < vocabulary.size(); i++) {
+            vector[i] = wordFreq.getOrDefault(vocabulary.get(i), 0);
+        }
+        return vector;
+    }
+    // Cosine similarity: dot(A, B) / (|A| * |B|)
+    private static double cosineSimilarity(double[] vecA, double[] vecB) {
+        double dotProduct = 0.0;
+        double normA = 0.0;
+        double normB = 0.0;
+        for (int i = 0; i < vecA.length; i++) {
+            dotProduct += vecA[i] * vecB[i];
+            normA += Math.pow(vecA[i], 2);
+            normB += Math.pow(vecB[i], 2);
+        }
+        // Guard against zero vectors, which would otherwise yield NaN
+        if (normA == 0 || normB == 0) {
+            return 0.0;
+        }
+        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+    }
+    public static double getDataSetSimilarity(String queryText, String datasetName) {
+        if (queryText == null || datasetName == null) {
+            return 0.0;
+        }
+        JiebaSegmenter segmenter = new JiebaSegmenter();
+        // 1. Segment both texts into words
+        List<String> words1 = segmenter.sentenceProcess(queryText);
+        List<String> words2 = segmenter.sentenceProcess(datasetName);
+        // 2. Build a duplicate-free vocabulary and generate the term-frequency vectors
+        Set<String> vocabSet = new LinkedHashSet<>(words1);
+        vocabSet.addAll(words2);
+        List<String> vocabulary = new ArrayList<>(vocabSet);
+        double[] vector1 = createVector(words1, vocabulary);
+        double[] vector2 = createVector(words2, vocabulary);
+        // 3. Compute the cosine similarity of the two vectors
+        return cosineSimilarity(vector1, vector2);
+    }
+}
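For intuition: with term-frequency vectors [1, 1, 0] and [1, 0, 1] the dot product is 1 and each norm is √2, so the cosine score is 1/2 = 0.5. A quick usage sketch for the new helper (SimilarityDemo is a hypothetical class; it assumes only that jieba-analysis is on the classpath, which the pom changes in this commit provide):

public class SimilarityDemo {
    public static void main(String[] args) {
        // Identical texts segment to identical token lists -> 1.0
        System.out.println(TextSimilarityCalculation.getDataSetSimilarity("销售数据", "销售数据"));
        // No shared tokens -> orthogonal vectors -> 0.0
        System.out.println(TextSimilarityCalculation.getDataSetSimilarity("销售数据", "abc"));
        // A realistic call: question text vs. dataset name; the exact score
        // depends on how jieba segments each string
        System.out.println(TextSimilarityCalculation.getDataSetSimilarity("查询每个部门的销售额", "销售数据集"));
    }
}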

pom.xml (parent)

@@ -82,6 +82,7 @@
         <stax2.version>4.2.2</stax2.version>
         <aws-java-sdk.version>1.12.780</aws-java-sdk.version>
         <jgrapht.version>1.5.2</jgrapht.version>
+        <jieba.version>1.0.2</jieba.version>
     </properties>
     <dependencyManagement>
@@ -216,6 +217,11 @@
                 <artifactId>jgrapht-core</artifactId>
                 <version>${jgrapht.version}</version>
             </dependency>
+            <dependency>
+                <groupId>com.huaban</groupId>
+                <artifactId>jieba-analysis</artifactId>
+                <version>${jieba.version}</version>
+            </dependency>
         </dependencies>
     </dependencyManagement>