[improvement][chat] Filter at the lowest level in the Map based on the dataSetId (#1834)

This commit is contained in:
lexluo09
2024-10-20 21:51:57 +08:00
committed by GitHub
parent 1d84e00887
commit 473329d398
108 changed files with 232 additions and 165 deletions

View File

@@ -3,38 +3,38 @@ package com.hankcs.hanlp;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@Data
@Slf4j
public class LoadRemoveService {
@Value("${s2.mapper.remove.nature.prefix:}")
private String mapperRemoveNaturePrefix;
public List removeNatures(List value) {
public List removeNatures(List value, Set<Long> modelIdOrDataSetIds) {
if (CollectionUtils.isEmpty(value)) {
return value;
}
List<String> resultList = new ArrayList<>(value);
if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) {
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
resultList.removeIf(nature -> {
if (Objects.isNull(nature)) {
return false;
}
return nature.startsWith(mapperRemoveNaturePrefix);
Long id = getId(nature);
if (Objects.nonNull(id)) {
return !modelIdOrDataSetIds.contains(id);
}
return false;
});
}
return resultList;
}
public Long getDataSetId(String nature) {
public Long getId(String nature) {
try {
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {

View File

@@ -20,16 +20,26 @@ import java.util.Set;
@Slf4j
public abstract class BaseNode<V> implements Comparable<BaseNode> {
/** 状态数组,方便读取的时候用 */
/**
* 状态数组,方便读取的时候用
*/
static final Status[] ARRAY_STATUS = Status.values();
/** 子节点 */
/**
* 子节点
*/
protected BaseNode[] child;
/** 节点状态 */
/**
* 节点状态
*/
protected Status status;
/** 节点代表的字符 */
/**
* 节点代表的字符
*/
protected char c;
/** 节点代表的值 */
/**
* 节点代表的值
*/
protected V value;
protected String prefix = null;
@@ -228,13 +238,21 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
}
public enum Status {
/** 未指定,用于删除词条 */
/**
* 未指定,用于删除词条
*/
UNDEFINED_0,
/** 不是词语的结尾 */
/**
* 不是词语的结尾
*/
NOT_WORD_1,
/** 是个词语的结尾,并且还可以继续 */
/**
* 是个词语的结尾,并且还可以继续
*/
WORD_MIDDLE_2,
/** 是个词语的结尾,并且没有继续 */
/**
* 是个词语的结尾,并且没有继续
*/
WORD_END_3,
}
@@ -257,10 +275,10 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
+ ", value=" + value + ", prefix='" + prefix + '\'' + '}';
}
public void walkNode(Set<Map.Entry<String, V>> entrySet) {
public void walkNode(Set<Map.Entry<String, V>> entrySet, Set<Long> modelIdOrDataSetIds) {
if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) {
log.debug("walkNode before:{}", value.toString());
List natures = new LoadRemoveService().removeNatures((List) value);
List natures = new LoadRemoveService().removeNatures((List) value, modelIdOrDataSetIds);
String name = this.prefix != null ? this.prefix + c : "" + c;
log.debug("walkNode name:{},after:{},natures:{}", name, (List) value, natures);
entrySet.add(new TrieEntry(name, (V) natures));
@@ -273,7 +291,8 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
* @param sb
* @param entrySet
*/
public void walkLimit(StringBuilder sb, Set<Map.Entry<String, V>> entrySet) {
public void walkLimit(StringBuilder sb, Set<Map.Entry<String, V>> entrySet,
Set<Long> modelIdOrDataSetIds) {
Queue<BaseNode> queue = new ArrayDeque<>();
this.prefix = sb.toString();
queue.add(this);
@@ -282,7 +301,7 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
if (root == null) {
continue;
}
root.walkNode(entrySet);
root.walkNode(entrySet, modelIdOrDataSetIds);
if (root.child == null) {
continue;
}

View File

@@ -1,9 +1,10 @@
package com.tencent.supersonic.common.config;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import java.util.Date;
import lombok.Data;
import java.util.Date;
@Data
public class ChatModel {
private Integer id;

View File

@@ -3,9 +3,10 @@ package com.tencent.supersonic.common.persistence.dataobject;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import java.util.Date;
import lombok.Data;
import java.util.Date;
@Data
@TableName("s2_chat_model")
public class ChatModelDO {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.service;
import com.tencent.supersonic.common.config.ChatModel;
import com.tencent.supersonic.common.pojo.User;
import java.util.List;
public interface ChatModelService {

View File

@@ -9,14 +9,15 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
@Service
public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModelDO>

View File

@@ -22,7 +22,6 @@ public class ChatAppManager {
public static Optional<ChatApp> getApp(String appKey) {
return chatApps.entrySet().stream().filter(e -> e.getKey().equals(appKey))
.map(Map.Entry::getValue)
.findFirst();
.map(Map.Entry::getValue).findFirst();
}
}