mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
[improvement][chat] Filter at the lowest level in the Map based on the dataSetId (#1834)
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
|
||||
@@ -16,6 +16,7 @@ import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -51,7 +52,9 @@ public class SearchService {
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit,
|
||||
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
|
||||
Set<Long> modelIdOrDataSetIds =
|
||||
findModelIdOrDataSetIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
Set<Map.Entry<String, List<String>>> result = search(key, binTrie, modelIdOrDataSetIds);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
double similarity = EditDistanceUtils.getSimilarity(name, key);
|
||||
@@ -77,7 +80,11 @@ public class SearchService {
|
||||
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String reverseDetectSegment = StringUtils.reverse(key);
|
||||
Set<Map.Entry<String, List<String>>> result = search(reverseDetectSegment, binTrie);
|
||||
Set<Long> modelIdOrDataSetIds =
|
||||
findModelIdOrDataSetIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
Set<Map.Entry<String, List<String>>> result =
|
||||
search(reverseDetectSegment, binTrie, modelIdOrDataSetIds);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
List<String> natures = entry.getValue().stream()
|
||||
@@ -115,7 +122,7 @@ public class SearchService {
|
||||
}
|
||||
|
||||
private static Set<Map.Entry<String, List<String>>> search(String key,
|
||||
BinTrie<List<String>> binTrie) {
|
||||
BinTrie<List<String>> binTrie, Set<Long> modelIdOrDataSetIds) {
|
||||
key = key.toLowerCase();
|
||||
Set<Map.Entry<String, List<String>>> entrySet =
|
||||
new TreeSet<Map.Entry<String, List<String>>>();
|
||||
@@ -136,7 +143,7 @@ public class SearchService {
|
||||
if (branch == null) {
|
||||
return entrySet;
|
||||
}
|
||||
branch.walkLimit(sb, entrySet);
|
||||
branch.walkLimit(sb, entrySet, modelIdOrDataSetIds);
|
||||
return entrySet;
|
||||
}
|
||||
|
||||
@@ -199,4 +206,23 @@ public class SearchService {
|
||||
}
|
||||
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Find all modelIds and dataSetIds based on the dataSetId
|
||||
*/
|
||||
public static Set<Long> findModelIdOrDataSetIds(Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
if (CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
if (CollectionUtils.isEmpty(modelIdToDataSetIds)) {
|
||||
return new HashSet<>(detectDataSetIds);
|
||||
}
|
||||
Set<Long> result = modelIdToDataSetIds.entrySet().stream()
|
||||
.filter(entry -> entry.getValue().stream().anyMatch(detectDataSetIds::contains))
|
||||
.map(Map.Entry::getKey).collect(Collectors.toSet());
|
||||
|
||||
result.addAll(detectDataSetIds);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ public class NatureHelper {
|
||||
|
||||
public static List<String> changeModel2DataSet(String nature,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) {
|
||||
if (isTerm(nature)) {
|
||||
return Collections.singletonList(nature);
|
||||
}
|
||||
Long modelId = getModelId(nature);
|
||||
@@ -103,6 +103,10 @@ public class NatureHelper {
|
||||
.filter(Objects::nonNull).map(String::valueOf).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static boolean isTerm(String nature) {
|
||||
return SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature));
|
||||
}
|
||||
|
||||
public static boolean isDimensionValueDataSetId(String nature) {
|
||||
return isNatureValid(nature)
|
||||
&& !isNatureType(nature, DictWordType.METRIC, DictWordType.DIMENSION,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
|
||||
Reference in New Issue
Block a user