From 7bb0f84bc3aeed7c0201138d3e80ceb66ee7e3a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8C=E6=AF=9B?= <1402564807@qq.com> Date: Mon, 1 Jul 2024 14:55:40 +0800 Subject: [PATCH] fix milvus shutdown error (#1307) --- .../service/impl/EmbeddingServiceImpl.java | 41 +++++++++---------- .../spring/MilvusEmbeddingStoreFactory.java | 5 ++- .../embedding/EmbeddingStoreFactory.java | 4 +- .../web/service/impl/TagMetaServiceImpl.java | 3 +- 4 files changed, 26 insertions(+), 27 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 3e19db67d..28d231863 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -4,6 +4,7 @@ import com.tencent.supersonic.common.service.EmbeddingService; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchResult; @@ -15,9 +16,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.TextSegmentConvert; import dev.langchain4j.store.embedding.filter.Filter; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.MapUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -25,31 +26,27 @@ import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @Service @Slf4j +@RequiredArgsConstructor public class EmbeddingServiceImpl implements EmbeddingService { - - @Autowired - private EmbeddingStoreFactory embeddingStoreFactory; - - @Autowired - private EmbeddingModel embeddingModel; + private static final Map> embeddingStoreMap = new ConcurrentHashMap<>(); + private final EmbeddingStoreFactory embeddingStoreFactory; + private final EmbeddingModel embeddingModel; @Override public void addQuery(String collectionName, List queries) { - EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); - for (TextSegment query : queries) { - String question = query.text(); - try { - Embedding embedding = embeddingModel.embed(question).content(); - embeddingStore.add(embedding, query); - } catch (Exception e) { - log.error("embeddingModel embed error question: {}, embeddingStore: {}", question, - embeddingStore.getClass().getSimpleName(), e); - } + EmbeddingStore embeddingStore = embeddingStoreMap + .computeIfAbsent(collectionName, embeddingStoreFactory::create); + try { + Response> embedAll = embeddingModel.embedAll(queries); + embeddingStore.addAll(embedAll.content(), queries); + } catch (Exception e) { + log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries, + embeddingStore.getClass().getSimpleName(), e); } } @@ -61,7 +58,8 @@ public class EmbeddingServiceImpl implements EmbeddingService { public List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { List results = new ArrayList<>(); - EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); + EmbeddingStore embeddingStore = embeddingStoreMap + .computeIfAbsent(collectionName, embeddingStoreFactory::create); List queryTextsList = retrieveQuery.getQueryTextsList(); Map filterCondition = retrieveQuery.getFilterCondition(); for (String queryText : queryTextsList) { @@ -70,7 +68,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() .queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); - EmbeddingSearchResult result = embeddingStore.search(request); + EmbeddingSearchResult result = embeddingStore.search(request); List> relevant = result.matches(); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); @@ -83,8 +81,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { retrieval.setId(TextSegmentConvert.getQueryId(embedded)); retrieval.setQuery(embedded.text()); Map metadata = new HashMap<>(); - if (Objects.nonNull(embedded) - && MapUtils.isNotEmpty(embedded.metadata().toMap())) { + if (MapUtils.isNotEmpty(embedded.metadata().toMap())) { metadata.putAll(embedded.metadata().toMap()); } retrieval.setMetadata(metadata); diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java index 1516439c6..0ca780057 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java @@ -1,18 +1,19 @@ package dev.langchain4j.milvus.spring; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory { - private Properties properties; + private final Properties properties; public MilvusEmbeddingStoreFactory(Properties properties) { this.properties = properties; } @Override - public EmbeddingStore create(String collectionName) { + public EmbeddingStore create(String collectionName) { EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore(); return MilvusEmbeddingStore.builder() .host(embeddingStore.getHost()) diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java index 32145b7fd..d818ea80d 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java @@ -1,6 +1,8 @@ package dev.langchain4j.store.embedding; +import dev.langchain4j.data.segment.TextSegment; + public interface EmbeddingStoreFactory { - EmbeddingStore create(String collectionName); + EmbeddingStore create(String collectionName); } \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/service/impl/TagMetaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/service/impl/TagMetaServiceImpl.java index 5452a5f8d..f25e28bc3 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/service/impl/TagMetaServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/web/service/impl/TagMetaServiceImpl.java @@ -133,8 +133,7 @@ public class TagMetaServiceImpl implements TagMetaService { @Override public List getTags(TagFilter tagFilter) { - List tagRespList = tagRepository.queryTagRespList(tagFilter); - return tagRespList; + return tagRepository.queryTagRespList(tagFilter); } @Override