fix milvus shutdown error (#1307)

This commit is contained in:
二毛
2024-07-01 14:55:40 +08:00
committed by GitHub
parent d2e9d1bf85
commit 7bb0f84bc3
4 changed files with 26 additions and 27 deletions

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
@@ -15,9 +16,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
@@ -25,31 +26,27 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
@Service
@Slf4j
@RequiredArgsConstructor
public class EmbeddingServiceImpl implements EmbeddingService {
@Autowired
private EmbeddingStoreFactory embeddingStoreFactory;
@Autowired
private EmbeddingModel embeddingModel;
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
private final EmbeddingStoreFactory embeddingStoreFactory;
private final EmbeddingModel embeddingModel;
@Override
public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
for (TextSegment query : queries) {
String question = query.text();
try {
Embedding embedding = embeddingModel.embed(question).content();
embeddingStore.add(embedding, query);
} catch (Exception e) {
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
embeddingStore.getClass().getSimpleName(), e);
}
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
try {
Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
embeddingStore.addAll(embedAll.content(), queries);
} catch (Exception e) {
log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
embeddingStore.getClass().getSimpleName(), e);
}
}
@@ -61,7 +58,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
List<RetrieveQueryResult> results = new ArrayList<>();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) {
@@ -70,7 +68,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult result = embeddingStore.search(request);
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
@@ -83,8 +81,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>();
if (Objects.nonNull(embedded)
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.metadata().toMap());
}
retrieval.setMetadata(metadata);

View File

@@ -1,18 +1,19 @@
package dev.langchain4j.milvus.spring;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
private Properties properties;
private final Properties properties;
public MilvusEmbeddingStoreFactory(Properties properties) {
this.properties = properties;
}
@Override
public EmbeddingStore create(String collectionName) {
public EmbeddingStore<TextSegment> create(String collectionName) {
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
return MilvusEmbeddingStore.builder()
.host(embeddingStore.getHost())

View File

@@ -1,6 +1,8 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.data.segment.TextSegment;
public interface EmbeddingStoreFactory {
EmbeddingStore create(String collectionName);
EmbeddingStore<TextSegment> create(String collectionName);
}

View File

@@ -133,8 +133,7 @@ public class TagMetaServiceImpl implements TagMetaService {
@Override
public List<TagResp> getTags(TagFilter tagFilter) {
List<TagResp> tagRespList = tagRepository.queryTagRespList(tagFilter);
return tagRespList;
return tagRepository.queryTagRespList(tagFilter);
}
@Override