diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index b28e6336a..1cdbcd6f3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -65,12 +65,16 @@ public class MemoryServiceImpl implements MemoryService, CommandLineRunner { ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId()); boolean hadEnabled = MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim()); - if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) { + + if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus())) { + // Update the latest SQL/Schema to vector DB once memory is enabled + chatMemoryDO.setS2sql(chatMemoryUpdateReq.getS2sql()); + chatMemoryDO.setDbSchema(chatMemoryUpdateReq.getDbSchema()); enableMemory(chatMemoryDO); - } else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) { + } else if ((MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus())||MemoryStatus.PENDING.equals(chatMemoryUpdateReq.getStatus())) && hadEnabled) { + // Remove from vector DB when transitioning: launched→disabled OR enabled→pending disableMemory(chatMemoryDO); } - LambdaUpdateWrapper updateWrapper = new LambdaUpdateWrapper<>(); updateWrapper.eq(ChatMemoryDO::getId, chatMemoryDO.getId()); if (Objects.nonNull(chatMemoryUpdateReq.getStatus())) { diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 9bd6563aa..b5694e26a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -49,11 +49,10 @@ public class EmbeddingServiceImpl implements EmbeddingService { try { EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); Embedding embedding = embeddingModel.embed(question).content(); - boolean existSegment = - existSegment(collectionName, embeddingStore, query, embedding); - if (existSegment) { - continue; - } + MetadataFilterBuilder filterBuilder = + new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID); + Filter filter = filterBuilder.isEqualTo(TextSegmentConvert.getQueryId(query)); + embeddingStore.removeAll(filter); embeddingStore.add(embedding, query); cache.put(TextSegmentConvert.getQueryId(query), true); } catch (Exception e) {