(improvement)(chat) queryDimensionValue query top10 in queue (#277)

This commit is contained in:
lexluo09
2023-10-23 20:38:13 +08:00
committed by GitHub
parent 62e2bf7de6
commit 4fbc3c8533
6 changed files with 56 additions and 17 deletions

View File

@@ -1,16 +1,21 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import javax.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@Data @Data
public class DimensionValueReq { public class DimensionValueReq {
private Integer agentId; private Integer agentId;
@NotNull
private Long elementID; private Long elementID;
@NotNull
private Long modelId; private Long modelId;
private String bizName; private String bizName;
@NotNull
private Object value; private Object value;
} }

View File

@@ -94,8 +94,4 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return semanticSchema.getMetrics(modelId); return semanticSchema.getMetrics(modelId);
} }
protected List<SchemaElement> getDimensionElements(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
return semanticSchema.getDimensions(modelId);
}
} }

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -32,9 +33,20 @@ public class GroupByCorrector extends BaseSemanticCorrector {
//add dimension group by //add dimension group by
String sql = semanticCorrectInfo.getSql(); String sql = semanticCorrectInfo.getSql();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
//add alias field name
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream() Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); .flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
}
).collect(Collectors.toSet());
dimensions.add(DateUtils.DATE_FIELD); dimensions.add(DateUtils.DATE_FIELD);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql); List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) { if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.SearchService; import com.tencent.supersonic.chat.service.SearchService;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
@@ -78,7 +79,7 @@ public class ChatQueryController {
} }
@PostMapping("queryDimensionValue") @PostMapping("queryDimensionValue")
public Object queryDimensionValue(@RequestBody DimensionValueReq dimensionValueReq, public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
HttpServletRequest request, HttpServletResponse response) HttpServletRequest request, HttpServletResponse response)
throws Exception { throws Exception {
return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response)); return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.service.impl; package com.tencent.supersonic.chat.service.impl;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SchemaMapper; import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
@@ -37,6 +38,7 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.SolvedQueryManager; import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
@@ -46,6 +48,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.knowledge.utils.NatureHelper; import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.knowledge.utils.HanlpHelper; import com.tencent.supersonic.knowledge.utils.HanlpHelper;
@@ -59,6 +62,7 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -590,8 +594,14 @@ public class QueryServiceImpl implements QueryService {
@Override @Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception { public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
if (Objects.isNull(dimensionValueReq.getValue())) { if (StringUtils.isBlank(dimensionValueReq.getValue().toString())) {
dimensionValueReq.setValue(""); String nature =
dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT + dimensionValueReq.getElementID();
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
if (CollectionUtils.isEmpty(terms)) {
return null;
}
return terms.stream().map(term -> term.getWord()).collect(Collectors.toSet());
} }
return queryHanlpDimensionValue(dimensionValueReq, user); return queryHanlpDimensionValue(dimensionValueReq, user);
} }

View File

@@ -11,12 +11,12 @@ import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary; import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary; import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
import com.hankcs.hanlp.dictionary.other.CharTable; import com.hankcs.hanlp.dictionary.other.CharTable;
import com.hankcs.hanlp.seg.common.Term;
import com.hankcs.hanlp.utility.LexiconUtility; import com.hankcs.hanlp.utility.LexiconUtility;
import com.hankcs.hanlp.utility.Predefine; import com.hankcs.hanlp.utility.Predefine;
import com.hankcs.hanlp.utility.TextUtility; import com.hankcs.hanlp.utility.TextUtility;
import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper; import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.io.BufferedOutputStream; import java.io.BufferedOutputStream;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.DataOutputStream; import java.io.DataOutputStream;
@@ -25,15 +25,21 @@ import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
public class MultiCustomDictionary extends DynamicCustomDictionary { public class MultiCustomDictionary extends DynamicCustomDictionary {
public static int MAX_SIZE = 10;
public static Boolean removeDuplicates = true; public static Boolean removeDuplicates = true;
public static ConcurrentHashMap<String, PriorityQueue<Term>> NATURE_TO_VALUES = new ConcurrentHashMap<>();
private static boolean addToSuggesterTrie = true; private static boolean addToSuggesterTrie = true;
public MultiCustomDictionary() { public MultiCustomDictionary() {
@@ -107,17 +113,26 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
} }
} }
attribute.original = original; attribute.original = original;
if (removeDuplicates && map.containsKey(word)) { if (removeDuplicates && map.containsKey(word)) {
attribute = DictionaryAttributeUtil.getAttribute(map.get(word), attribute); attribute = DictionaryAttributeUtil.getAttribute(map.get(word), attribute);
}
map.put(word, attribute); map.put(word, attribute);
if (addToSuggeterTrie) { if (addToSuggeterTrie) {
SearchService.put(word, attribute); SearchService.put(word, attribute);
} }
for (int i = 0; i < attribute.nature.length; i++) {
} else { Nature nature = attribute.nature[i];
map.put(word, attribute); PriorityQueue<Term> priorityQueue = NATURE_TO_VALUES.get(nature.toString());
if (addToSuggeterTrie) { if (Objects.isNull(priorityQueue)) {
SearchService.put(word, attribute); priorityQueue = new PriorityQueue<>(MAX_SIZE,
Comparator.comparingInt(Term::getFrequency).reversed());
NATURE_TO_VALUES.put(nature.toString(), priorityQueue);
}
Term term = new Term(word, nature);
term.setFrequency(attribute.frequency[i]);
if (!priorityQueue.contains(term) && priorityQueue.size() < MAX_SIZE) {
priorityQueue.add(term);
} }
} }
} }