mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
fix milvus shutdown error (#1307)
This commit is contained in:
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
@@ -15,9 +16,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -25,31 +26,27 @@ import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingModel embeddingModel;
|
||||
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
|
||||
private final EmbeddingStoreFactory embeddingStoreFactory;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
for (TextSegment query : queries) {
|
||||
String question = query.text();
|
||||
try {
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
embeddingStore.add(embedding, query);
|
||||
} catch (Exception e) {
|
||||
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
|
||||
embeddingStore.getClass().getSimpleName(), e);
|
||||
}
|
||||
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
||||
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
||||
try {
|
||||
Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
|
||||
embeddingStore.addAll(embedAll.content(), queries);
|
||||
} catch (Exception e) {
|
||||
log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
|
||||
embeddingStore.getClass().getSimpleName(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +58,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
||||
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
for (String queryText : queryTextsList) {
|
||||
@@ -70,7 +68,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
||||
|
||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
|
||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||
|
||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||
@@ -83,8 +81,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||
retrieval.setQuery(embedded.text());
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
if (Objects.nonNull(embedded)
|
||||
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||
if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||
metadata.putAll(embedded.metadata().toMap());
|
||||
}
|
||||
retrieval.setMetadata(metadata);
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
package dev.langchain4j.milvus.spring;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
||||
|
||||
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
private Properties properties;
|
||||
private final Properties properties;
|
||||
|
||||
public MilvusEmbeddingStoreFactory(Properties properties) {
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingStore create(String collectionName) {
|
||||
public EmbeddingStore<TextSegment> create(String collectionName) {
|
||||
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
||||
return MilvusEmbeddingStore.builder()
|
||||
.host(embeddingStore.getHost())
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
|
||||
public interface EmbeddingStoreFactory {
|
||||
|
||||
EmbeddingStore create(String collectionName);
|
||||
EmbeddingStore<TextSegment> create(String collectionName);
|
||||
}
|
||||
@@ -133,8 +133,7 @@ public class TagMetaServiceImpl implements TagMetaService {
|
||||
|
||||
@Override
|
||||
public List<TagResp> getTags(TagFilter tagFilter) {
|
||||
List<TagResp> tagRespList = tagRepository.queryTagRespList(tagFilter);
|
||||
return tagRespList;
|
||||
return tagRepository.queryTagRespList(tagFilter);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user