[improvement][chat] Filter at the lowest level in the Map based on the dataSetId (#1834)

This commit is contained in:
lexluo09
2024-10-20 21:51:57 +08:00
committed by GitHub
parent 1d84e00887
commit 473329d398
108 changed files with 232 additions and 165 deletions

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;

View File

@@ -16,6 +16,7 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -51,7 +52,9 @@ public class SearchService {
public static List<HanlpMapResult> prefixSearch(String key, int limit,
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
Set<Long> modelIdOrDataSetIds =
findModelIdOrDataSetIds(modelIdToDataSetIds, detectDataSetIds);
Set<Map.Entry<String, List<String>>> result = search(key, binTrie, modelIdOrDataSetIds);
List<HanlpMapResult> hanlpMapResults = result.stream().map(entry -> {
String name = entry.getKey().replace("#", " ");
double similarity = EditDistanceUtils.getSimilarity(name, key);
@@ -77,7 +80,11 @@ public class SearchService {
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
String reverseDetectSegment = StringUtils.reverse(key);
Set<Map.Entry<String, List<String>>> result = search(reverseDetectSegment, binTrie);
Set<Long> modelIdOrDataSetIds =
findModelIdOrDataSetIds(modelIdToDataSetIds, detectDataSetIds);
Set<Map.Entry<String, List<String>>> result =
search(reverseDetectSegment, binTrie, modelIdOrDataSetIds);
List<HanlpMapResult> hanlpMapResults = result.stream().map(entry -> {
String name = entry.getKey().replace("#", " ");
List<String> natures = entry.getValue().stream()
@@ -115,7 +122,7 @@ public class SearchService {
}
private static Set<Map.Entry<String, List<String>>> search(String key,
BinTrie<List<String>> binTrie) {
BinTrie<List<String>> binTrie, Set<Long> modelIdOrDataSetIds) {
key = key.toLowerCase();
Set<Map.Entry<String, List<String>>> entrySet =
new TreeSet<Map.Entry<String, List<String>>>();
@@ -136,7 +143,7 @@ public class SearchService {
if (branch == null) {
return entrySet;
}
branch.walkLimit(sb, entrySet);
branch.walkLimit(sb, entrySet, modelIdOrDataSetIds);
return entrySet;
}
@@ -199,4 +206,23 @@ public class SearchService {
}
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
}
/**
 * Collects the ids to filter on for a set of detected data sets: the detected dataSetIds
 * themselves plus the id of every model whose data-set list contains at least one of them.
 *
 * @param modelIdToDataSetIds mapping from modelId to the dataSetIds associated with that
 *        model; may be {@code null} or empty, and individual lists may be {@code null}
 * @param detectDataSetIds the dataSetIds to look up; may be {@code null} or empty
 * @return a new, mutable set holding the matching modelIds and all of
 *         {@code detectDataSetIds}; an empty set when {@code detectDataSetIds} is
 *         {@code null} or empty
 */
public static Set<Long> findModelIdOrDataSetIds(Map<Long, List<Long>> modelIdToDataSetIds,
        Set<Long> detectDataSetIds) {
    if (detectDataSetIds == null || detectDataSetIds.isEmpty()) {
        return new HashSet<>();
    }
    // Start from an explicit HashSet so the result is guaranteed to be mutable;
    // Collectors.toSet() makes no promise about the mutability of the set it returns.
    Set<Long> result = new HashSet<>(detectDataSetIds);
    if (modelIdToDataSetIds == null || modelIdToDataSetIds.isEmpty()) {
        return result;
    }
    for (Map.Entry<Long, List<Long>> entry : modelIdToDataSetIds.entrySet()) {
        List<Long> dataSetIds = entry.getValue();
        // A model mapped to a null list cannot match anything; skip it instead of failing.
        if (dataSetIds != null && dataSetIds.stream().anyMatch(detectDataSetIds::contains)) {
            result.add(entry.getKey());
        }
    }
    return result;
}
}

View File

@@ -91,7 +91,7 @@ public class NatureHelper {
public static List<String> changeModel2DataSet(String nature,
Map<Long, List<Long>> modelIdToDataSetIds) {
if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) {
if (isTerm(nature)) {
return Collections.singletonList(nature);
}
Long modelId = getModelId(nature);
@@ -103,6 +103,10 @@ public class NatureHelper {
.filter(Objects::nonNull).map(String::valueOf).collect(Collectors.toList());
}
/**
 * Tells whether the given nature resolves to the {@code TERM} schema element type.
 *
 * @param nature the nature string to classify
 * @return {@code true} when {@code convertToElementType(nature)} equals
 *         {@code SchemaElementType.TERM}
 */
public static boolean isTerm(String nature) {
    return Objects.equals(SchemaElementType.TERM, NatureHelper.convertToElementType(nature));
}
public static boolean isDimensionValueDataSetId(String nature) {
return isNatureValid(nature)
&& !isNatureType(nature, DictWordType.METRIC, DictWordType.DIMENSION,

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.headless.chat.parser;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;