mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
[improvement][chat] Filter at the lowest level in the Map based on the dataSetId (#1834)
This commit is contained in:
@@ -3,38 +3,38 @@ package com.hankcs.hanlp;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
public class LoadRemoveService {
|
||||
|
||||
@Value("${s2.mapper.remove.nature.prefix:}")
|
||||
private String mapperRemoveNaturePrefix;
|
||||
|
||||
public List removeNatures(List value) {
|
||||
public List removeNatures(List value, Set<Long> modelIdOrDataSetIds) {
|
||||
if (CollectionUtils.isEmpty(value)) {
|
||||
return value;
|
||||
}
|
||||
List<String> resultList = new ArrayList<>(value);
|
||||
if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) {
|
||||
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
|
||||
resultList.removeIf(nature -> {
|
||||
if (Objects.isNull(nature)) {
|
||||
return false;
|
||||
}
|
||||
return nature.startsWith(mapperRemoveNaturePrefix);
|
||||
Long id = getId(nature);
|
||||
if (Objects.nonNull(id)) {
|
||||
return !modelIdOrDataSetIds.contains(id);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
public Long getDataSetId(String nature) {
|
||||
public Long getId(String nature) {
|
||||
try {
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length <= 1) {
|
||||
|
||||
@@ -20,16 +20,26 @@ import java.util.Set;
|
||||
@Slf4j
|
||||
public abstract class BaseNode<V> implements Comparable<BaseNode> {
|
||||
|
||||
/** 状态数组,方便读取的时候用 */
|
||||
/**
|
||||
* 状态数组,方便读取的时候用
|
||||
*/
|
||||
static final Status[] ARRAY_STATUS = Status.values();
|
||||
|
||||
/** 子节点 */
|
||||
/**
|
||||
* 子节点
|
||||
*/
|
||||
protected BaseNode[] child;
|
||||
/** 节点状态 */
|
||||
/**
|
||||
* 节点状态
|
||||
*/
|
||||
protected Status status;
|
||||
/** 节点代表的字符 */
|
||||
/**
|
||||
* 节点代表的字符
|
||||
*/
|
||||
protected char c;
|
||||
/** 节点代表的值 */
|
||||
/**
|
||||
* 节点代表的值
|
||||
*/
|
||||
protected V value;
|
||||
|
||||
protected String prefix = null;
|
||||
@@ -228,13 +238,21 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
|
||||
}
|
||||
|
||||
public enum Status {
|
||||
/** 未指定,用于删除词条 */
|
||||
/**
|
||||
* 未指定,用于删除词条
|
||||
*/
|
||||
UNDEFINED_0,
|
||||
/** 不是词语的结尾 */
|
||||
/**
|
||||
* 不是词语的结尾
|
||||
*/
|
||||
NOT_WORD_1,
|
||||
/** 是个词语的结尾,并且还可以继续 */
|
||||
/**
|
||||
* 是个词语的结尾,并且还可以继续
|
||||
*/
|
||||
WORD_MIDDLE_2,
|
||||
/** 是个词语的结尾,并且没有继续 */
|
||||
/**
|
||||
* 是个词语的结尾,并且没有继续
|
||||
*/
|
||||
WORD_END_3,
|
||||
}
|
||||
|
||||
@@ -257,10 +275,10 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
|
||||
+ ", value=" + value + ", prefix='" + prefix + '\'' + '}';
|
||||
}
|
||||
|
||||
public void walkNode(Set<Map.Entry<String, V>> entrySet) {
|
||||
public void walkNode(Set<Map.Entry<String, V>> entrySet, Set<Long> modelIdOrDataSetIds) {
|
||||
if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) {
|
||||
log.debug("walkNode before:{}", value.toString());
|
||||
List natures = new LoadRemoveService().removeNatures((List) value);
|
||||
List natures = new LoadRemoveService().removeNatures((List) value, modelIdOrDataSetIds);
|
||||
String name = this.prefix != null ? this.prefix + c : "" + c;
|
||||
log.debug("walkNode name:{},after:{},natures:{}", name, (List) value, natures);
|
||||
entrySet.add(new TrieEntry(name, (V) natures));
|
||||
@@ -273,7 +291,8 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
|
||||
* @param sb
|
||||
* @param entrySet
|
||||
*/
|
||||
public void walkLimit(StringBuilder sb, Set<Map.Entry<String, V>> entrySet) {
|
||||
public void walkLimit(StringBuilder sb, Set<Map.Entry<String, V>> entrySet,
|
||||
Set<Long> modelIdOrDataSetIds) {
|
||||
Queue<BaseNode> queue = new ArrayDeque<>();
|
||||
this.prefix = sb.toString();
|
||||
queue.add(this);
|
||||
@@ -282,7 +301,7 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
|
||||
if (root == null) {
|
||||
continue;
|
||||
}
|
||||
root.walkNode(entrySet);
|
||||
root.walkNode(entrySet, modelIdOrDataSetIds);
|
||||
if (root.child == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import java.util.Date;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
public class ChatModel {
|
||||
private Integer id;
|
||||
|
||||
@@ -3,9 +3,10 @@ package com.tencent.supersonic.common.persistence.dataobject;
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import java.util.Date;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@TableName("s2_chat_model")
|
||||
public class ChatModelDO {
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.service;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatModelService {
|
||||
|
||||
@@ -9,14 +9,15 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModelDO>
|
||||
|
||||
@@ -22,7 +22,6 @@ public class ChatAppManager {
|
||||
|
||||
public static Optional<ChatApp> getApp(String appKey) {
|
||||
return chatApps.entrySet().stream().filter(e -> e.getKey().equals(appKey))
|
||||
.map(Map.Entry::getValue)
|
||||
.findFirst();
|
||||
.map(Map.Entry::getValue).findFirst();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user