fix milvus shutdown error (#1307)

This commit is contained in:
二毛
2024-07-01 14:55:40 +08:00
committed by GitHub
parent d2e9d1bf85
commit 7bb0f84bc3
4 changed files with 26 additions and 27 deletions

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingSearchResult;
@@ -15,9 +16,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert; import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter; import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
@@ -25,33 +26,29 @@ import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Service @Service
@Slf4j @Slf4j
@RequiredArgsConstructor
public class EmbeddingServiceImpl implements EmbeddingService { public class EmbeddingServiceImpl implements EmbeddingService {
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
@Autowired private final EmbeddingStoreFactory embeddingStoreFactory;
private EmbeddingStoreFactory embeddingStoreFactory; private final EmbeddingModel embeddingModel;
@Autowired
private EmbeddingModel embeddingModel;
@Override @Override
public void addQuery(String collectionName, List<TextSegment> queries) { public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
for (TextSegment query : queries) { .computeIfAbsent(collectionName, embeddingStoreFactory::create);
String question = query.text();
try { try {
Embedding embedding = embeddingModel.embed(question).content(); Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
embeddingStore.add(embedding, query); embeddingStore.addAll(embedAll.content(), queries);
} catch (Exception e) { } catch (Exception e) {
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question, log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
embeddingStore.getClass().getSimpleName(), e); embeddingStore.getClass().getSimpleName(), e);
} }
} }
}
@Override @Override
public void deleteQuery(String collectionName, List<TextSegment> queries) { public void deleteQuery(String collectionName, List<TextSegment> queries) {
@@ -61,7 +58,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
List<RetrieveQueryResult> results = new ArrayList<>(); List<RetrieveQueryResult> results = new ArrayList<>();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
List<String> queryTextsList = retrieveQuery.getQueryTextsList(); List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition(); Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) { for (String queryText : queryTextsList) {
@@ -70,7 +68,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); .queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult result = embeddingStore.search(request); EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
List<EmbeddingMatch<TextSegment>> relevant = result.matches(); List<EmbeddingMatch<TextSegment>> relevant = result.matches();
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
@@ -83,8 +81,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
retrieval.setId(TextSegmentConvert.getQueryId(embedded)); retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text()); retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>(); Map<String, Object> metadata = new HashMap<>();
if (Objects.nonNull(embedded) if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.metadata().toMap()); metadata.putAll(embedded.metadata().toMap());
} }
retrieval.setMetadata(metadata); retrieval.setMetadata(metadata);

View File

@@ -1,18 +1,19 @@
package dev.langchain4j.milvus.spring; package dev.langchain4j.milvus.spring;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory { public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
private Properties properties; private final Properties properties;
public MilvusEmbeddingStoreFactory(Properties properties) { public MilvusEmbeddingStoreFactory(Properties properties) {
this.properties = properties; this.properties = properties;
} }
@Override @Override
public EmbeddingStore create(String collectionName) { public EmbeddingStore<TextSegment> create(String collectionName) {
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore(); EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
return MilvusEmbeddingStore.builder() return MilvusEmbeddingStore.builder()
.host(embeddingStore.getHost()) .host(embeddingStore.getHost())

View File

@@ -1,6 +1,8 @@
package dev.langchain4j.store.embedding; package dev.langchain4j.store.embedding;
import dev.langchain4j.data.segment.TextSegment;
public interface EmbeddingStoreFactory { public interface EmbeddingStoreFactory {
EmbeddingStore create(String collectionName); EmbeddingStore<TextSegment> create(String collectionName);
} }

View File

@@ -133,8 +133,7 @@ public class TagMetaServiceImpl implements TagMetaService {
@Override @Override
public List<TagResp> getTags(TagFilter tagFilter) { public List<TagResp> getTags(TagFilter tagFilter) {
List<TagResp> tagRespList = tagRepository.queryTagRespList(tagFilter); return tagRepository.queryTagRespList(tagFilter);
return tagRespList;
} }
@Override @Override