From 24b0be756691f3939ac0ce7e4fba416f5f236e61 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 10 Nov 2023 21:47:20 +0800 Subject: [PATCH] (improvement)(chat) when added to SchemaMap, duplicate elements are removed (#365) --- .../chat/api/pojo/SchemaElement.java | 6 ++--- .../supersonic/chat/mapper/BaseMapper.java | 23 +++++++++++++++++-- .../chat/mapper/EmbeddingMapper.java | 8 +++---- .../chat/service/impl/QueryServiceImpl.java | 1 - 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java index eb929e3d4..10a3c54d9 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SchemaElement.java @@ -41,13 +41,13 @@ public class SchemaElement implements Serializable { SchemaElement schemaElement = (SchemaElement) o; return Objects.equal(model, schemaElement.model) && Objects.equal(id, schemaElement.id) && Objects.equal(name, schemaElement.name) - && Objects.equal(bizName, schemaElement.bizName) && Objects.equal( - useCnt, schemaElement.useCnt) && Objects.equal(type, schemaElement.type); + && Objects.equal(bizName, schemaElement.bizName) + && Objects.equal(type, schemaElement.type); } @Override public int hashCode() { - return Objects.hashCode(model, id, name, bizName, useCnt, type); + return Objects.hashCode(model, id, name, bizName, type); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java index a8040e3ef..4388e59db 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/BaseMapper.java @@ -13,6 +13,7 @@ import java.util.ArrayList; import 
java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @@ -44,13 +45,31 @@ public abstract class BaseMapper implements SchemaMapper { public abstract void doMap(QueryContext queryContext); - public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch schemaElementMatch) { + public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) { Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches(); List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>()); if (schemaElementMatches == null) { schemaElementMatches = modelElementMatches.get(modelId); } - schemaElementMatches.add(schemaElementMatch); + //remove duplication + AtomicBoolean needAddNew = new AtomicBoolean(true); + schemaElementMatches.removeIf( + existElementMatch -> { + SchemaElement existElement = existElementMatch.getElement(); + SchemaElement newElement = newElementMatch.getElement(); + if (existElement.equals(newElement)) { + if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) { + return true; + } else { + needAddNew.set(false); + } + } + return false; + } + ); + if (needAddNew.get()) { + schemaElementMatches.add(newElementMatch); + } } public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java index 6c41881ed..c7cafc9fa 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java @@ -33,18 +33,16 @@ public class EmbeddingMapper extends BaseMapper { 
HanlpHelper.transLetterOriginal(matchResults); //2. build SchemaElementMatch by info - MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); for (EmbeddingResult matchResult : matchResults) { Long elementId = Retrieval.getLongId(matchResult.getId()); SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()), SchemaElement.class); - String modelIdStr = matchResult.getMetadata().get("modelId"); - if (StringUtils.isBlank(modelIdStr)) { + if (StringUtils.isBlank(matchResult.getMetadata().get("modelId"))) { continue; } - long modelId = Long.parseLong(modelIdStr); + long modelId = Long.parseLong(matchResult.getMetadata().get("modelId")); schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId); if (schemaElement == null) { @@ -54,7 +52,7 @@ public class EmbeddingMapper extends BaseMapper { .element(schemaElement) .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) .word(matchResult.getName()) - .similarity(mapperHelper.getSimilarity(matchResult.getName(), matchResult.getDetectWord())) + .similarity(1 - matchResult.getDistance()) .detectWord(matchResult.getDetectWord()) .build(); //3. add to mapInfo diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 714f0735e..068d00e61 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -131,7 +131,6 @@ public class QueryServiceImpl implements QueryService { mapper.map(queryCtx); timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build()); - log.info("{} result:{}", mapper.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); }); //3. parser