diff --git a/headless/chat/pom.xml b/headless/chat/pom.xml
index 9847f0624..71d6445f4 100644
--- a/headless/chat/pom.xml
+++ b/headless/chat/pom.xml
@@ -122,6 +122,11 @@
${mockito-inline.version}
test
+        <dependency>
+            <groupId>com.huaban</groupId>
+            <artifactId>jieba-analysis</artifactId>
+            <version>${jieba.version}</version>
+        </dependency>
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java
index c1f6591a1..33866d8c3 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java
@@ -1,98 +1,31 @@
package com.tencent.supersonic.headless.chat.parser.llm;
-import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
-import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
-import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
-import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
+import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.collections.CollectionUtils;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
import java.util.Map.Entry;
-import java.util.Objects;
-import java.util.Set;
+
+import static com.tencent.supersonic.headless.chat.parser.llm.TextSimilarityCalculation.getDataSetSimilarity;
/**
- * HeuristicDataSetResolver select ONE most suitable data set out of matched data sets. The
- * selection is based on similarity comparison rule and the priority is like: 1.
- * maxSimilarity(matched dataset) 2. maxSimilarity(all matched metrics) 3. totalSimilarity(all
- * matched elements)
+ * HeuristicDataSetResolver selects the ONE most suitable data set from the data sets of the
+ * semantic schema, restricted to the agent's data sets when such a restriction is given. The
+ * choice is the data set whose name has the highest cosine similarity to the question text.
*/
@Slf4j
public class HeuristicDataSetResolver implements DataSetResolver {
+
+    /**
+     * Picks the data set whose name is most similar to the query text.
+     *
+     * @param chatQueryContext supplies the query text and the semantic schema's data sets
+     * @param agentDataSetIds ids the agent may use; {@code null} or empty means no restriction
+     * @return the id of the selected data set, or {@code null} when there is no candidate
+     */
public Long resolve(ChatQueryContext chatQueryContext, Set agentDataSetIds) {
- SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
- Set matchedDataSets = mapInfo.getMatchedDataSetInfos();
- if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
- matchedDataSets.retainAll(agentDataSetIds);
+        String queryText = chatQueryContext.getRequest().getQueryText();
+        List<SchemaElement> dataSets = chatQueryContext.getSemanticSchema().getDataSets();
+        // Keep honoring the agent restriction that the previous implementation applied.
+        boolean restricted = agentDataSetIds != null && !agentDataSetIds.isEmpty();
+        List<SchemaElement> candidates = new ArrayList<>();
+        if (dataSets != null) {
+            for (SchemaElement dataSet : dataSets) {
+                if (!restricted || agentDataSetIds.contains(dataSet.getDataSetId())) {
+                    candidates.add(dataSet);
+                }
+            }
+        }
+        if (candidates.isEmpty()) {
+            // Nothing to choose from; calling Optional.get() on an empty max() would throw.
+            return null;
+        }
+        if (candidates.size() == 1) {
+            // A single candidate needs no scoring.
+            return candidates.get(0).getDataSetId();
}
- if (matchedDataSets.size() == 1) {
- return matchedDataSets.stream().findFirst().get();
+        SchemaElement best = null;
+        double bestScore = Double.NEGATIVE_INFINITY;
+        for (SchemaElement candidate : candidates) {
+            double score = getDataSetSimilarity(queryText, candidate.getDataSetName());
+            if (Double.isNaN(score)) {
+                // NaN orders above every real number, so it must never be allowed to win.
+                score = 0.0;
+            }
+            // Strict comparison keeps the earliest candidate on ties.
+            if (score > bestScore) {
+                best = candidate;
+                bestScore = score;
+            }
}
- return selectDataSetByMatchSimilarity(mapInfo);
- }
-
- protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
- Map dataSetMatchRet =
- getDataSetMatchResult(schemaMap);
- Entry selectedDataset =
- dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> {
- double difference = o1.getValue().getMaxDatesetSimilarity()
- - o2.getValue().getMaxDatesetSimilarity();
- if (difference == 0) {
- difference = o1.getValue().getMaxMetricSimilarity()
- - o2.getValue().getMaxMetricSimilarity();
- if (difference == 0) {
- difference = o1.getValue().getTotalSimilarity()
- - o2.getValue().getTotalSimilarity();
- }
- if (difference == 0) {
- difference = o1.getValue().getMaxMetricUseCnt()
- - o2.getValue().getMaxMetricUseCnt();
- }
- }
- return difference >= 0 ? -1 : 1;
- }).findFirst().orElse(null);
- if (selectedDataset != null) {
- log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey());
- return selectedDataset.getKey();
- }
-
- return null;
- }
-
- protected Map getDataSetMatchResult(
- SchemaMapInfo schemaMap) {
- Map dateSetMatchRet = new HashMap<>();
- for (Entry> entry : schemaMap.getDataSetElementMatches()
- .entrySet()) {
- double maxMetricSimilarity = 0;
- double maxDatasetSimilarity = 0;
- double totalSimilarity = 0;
- long maxMetricUseCnt = 0L;
- for (SchemaElementMatch match : entry.getValue()) {
- if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
- maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
- }
- if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
- maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
- if (Objects.nonNull(match.getElement().getUseCnt())) {
- maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
- }
- }
- totalSimilarity += match.getSimilarity();
- }
- dateSetMatchRet.put(entry.getKey(),
- SemanticParseInfo.DataSetMatchResult.builder()
- .maxMetricSimilarity(maxMetricSimilarity)
- .maxDatesetSimilarity(maxDatasetSimilarity)
- .totalSimilarity(totalSimilarity).build());
- }
-
- return dateSetMatchRet;
+        log.info("selectDataSet with multiple DataSets [{}]", best.getDataSetId());
+        return best.getDataSetId();
}
}
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/TextSimilarityCalculation.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/TextSimilarityCalculation.java
new file mode 100644
index 000000000..4c2305abf
--- /dev/null
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/TextSimilarityCalculation.java
@@ -0,0 +1,52 @@
+package com.tencent.supersonic.headless.chat.parser.llm;
+
+import com.huaban.analysis.jieba.JiebaSegmenter;
+import lombok.extern.slf4j.Slf4j;
+
+import java.util.*;
+
+@Slf4j
+public class TextSimilarityCalculation {
+
+    /**
+     * Builds the term-frequency vector of {@code words} over {@code vocabulary}.
+     *
+     * @param words the segmented words of one text
+     * @param vocabulary the ordered list of distinct words that defines the vector dimensions
+     * @return an array where element {@code i} is how often {@code vocabulary.get(i)} occurs in
+     *         {@code words}
+     */
+    private static double[] createVector(List<String> words, List<String> vocabulary) {
+        Map<String, Integer> wordFreq = new HashMap<>();
+        for (String word : words) {
+            wordFreq.merge(word, 1, Integer::sum);
+        }
+        double[] vector = new double[vocabulary.size()];
+        for (int i = 0; i < vector.length; i++) {
+            vector[i] = wordFreq.getOrDefault(vocabulary.get(i), 0);
+        }
+        return vector;
+    }
+
+    /**
+     * Computes the cosine similarity of two vectors of equal length.
+     *
+     * @param vecA the first vector
+     * @param vecB the second vector, indexed over the same dimensions as {@code vecA}
+     * @return the cosine of the angle between the vectors, or {@code 0.0} when either vector is
+     *         all zeros
+     */
+    private static double cosineSimilarity(double[] vecA, double[] vecB) {
+        double dotProduct = 0.0;
+        double normA = 0.0;
+        double normB = 0.0;
+        for (int i = 0; i < vecA.length; i++) {
+            dotProduct += vecA[i] * vecB[i];
+            normA += vecA[i] * vecA[i];
+            normB += vecB[i] * vecB[i];
+        }
+        if (normA == 0.0 || normB == 0.0) {
+            // A zero vector has no direction; dividing would give NaN, which Double ordering
+            // ranks above every real score.
+            return 0.0;
+        }
+        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+    }
+
+    /**
+     * Scores how similar a query is to a data set name, using the cosine similarity of their
+     * word-frequency vectors after Jieba segmentation.
+     *
+     * @param queryText the user's question; may be {@code null}
+     * @param datasetName the data set name; may be {@code null}
+     * @return the similarity score, or {@code 0.0} when either argument is {@code null} or
+     *         segments to no words
+     */
+    public static double getDataSetSimilarity(String queryText, String datasetName) {
+        if (queryText == null || datasetName == null) {
+            return 0.0;
+        }
+        JiebaSegmenter segmenter = new JiebaSegmenter();
+
+        // 1. Segment both texts into words.
+        List<String> queryWords = segmenter.sentenceProcess(queryText);
+        List<String> nameWords = segmenter.sentenceProcess(datasetName);
+
+        // 2. Build the vocabulary as a set union so a word shared by both texts gets exactly one
+        //    dimension; listing it twice would double its weight and distort the cosine.
+        Set<String> distinctWords = new LinkedHashSet<>(queryWords);
+        distinctWords.addAll(nameWords);
+        List<String> vocabulary = new ArrayList<>(distinctWords);
+
+        // 3. Compare the two term-frequency vectors.
+        double[] queryVector = createVector(queryWords, vocabulary);
+        double[] nameVector = createVector(nameWords, vocabulary);
+        return cosineSimilarity(queryVector, nameVector);
+    }
+}
diff --git a/pom.xml b/pom.xml
index 8db6a8336..69940228b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -82,6 +82,7 @@
4.2.2
1.12.780
1.5.2
+        <jieba.version>1.0.2</jieba.version>
@@ -216,6 +217,11 @@
jgrapht-core
${jgrapht.version}
+            <dependency>
+                <groupId>com.huaban</groupId>
+                <artifactId>jieba-analysis</artifactId>
+                <version>${jieba.version}</version>
+            </dependency>