mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
fix milvus shutdown error (#1307)
This commit is contained in:
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.service.EmbeddingService;
|
|||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||||
@@ -15,9 +16,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
|||||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import dev.langchain4j.store.embedding.filter.Filter;
|
import dev.langchain4j.store.embedding.filter.Filter;
|
||||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -25,31 +26,27 @@ import java.util.Comparator;
|
|||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||||
|
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
|
||||||
@Autowired
|
private final EmbeddingStoreFactory embeddingStoreFactory;
|
||||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
private final EmbeddingModel embeddingModel;
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private EmbeddingModel embeddingModel;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
||||||
for (TextSegment query : queries) {
|
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
||||||
String question = query.text();
|
try {
|
||||||
try {
|
Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
|
||||||
Embedding embedding = embeddingModel.embed(question).content();
|
embeddingStore.addAll(embedAll.content(), queries);
|
||||||
embeddingStore.add(embedding, query);
|
} catch (Exception e) {
|
||||||
} catch (Exception e) {
|
log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
|
||||||
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
|
embeddingStore.getClass().getSimpleName(), e);
|
||||||
embeddingStore.getClass().getSimpleName(), e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +58,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||||
|
|
||||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
||||||
|
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
||||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||||
for (String queryText : queryTextsList) {
|
for (String queryText : queryTextsList) {
|
||||||
@@ -70,7 +68,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
||||||
|
|
||||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||||
|
|
||||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||||
@@ -83,8 +81,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||||
retrieval.setQuery(embedded.text());
|
retrieval.setQuery(embedded.text());
|
||||||
Map<String, Object> metadata = new HashMap<>();
|
Map<String, Object> metadata = new HashMap<>();
|
||||||
if (Objects.nonNull(embedded)
|
if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||||
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
|
||||||
metadata.putAll(embedded.metadata().toMap());
|
metadata.putAll(embedded.metadata().toMap());
|
||||||
}
|
}
|
||||||
retrieval.setMetadata(metadata);
|
retrieval.setMetadata(metadata);
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
package dev.langchain4j.milvus.spring;
|
package dev.langchain4j.milvus.spring;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
||||||
|
|
||||||
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||||
private Properties properties;
|
private final Properties properties;
|
||||||
|
|
||||||
public MilvusEmbeddingStoreFactory(Properties properties) {
|
public MilvusEmbeddingStoreFactory(Properties properties) {
|
||||||
this.properties = properties;
|
this.properties = properties;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingStore create(String collectionName) {
|
public EmbeddingStore<TextSegment> create(String collectionName) {
|
||||||
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
||||||
return MilvusEmbeddingStore.builder()
|
return MilvusEmbeddingStore.builder()
|
||||||
.host(embeddingStore.getHost())
|
.host(embeddingStore.getHost())
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package dev.langchain4j.store.embedding;
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
|
||||||
public interface EmbeddingStoreFactory {
|
public interface EmbeddingStoreFactory {
|
||||||
|
|
||||||
EmbeddingStore create(String collectionName);
|
EmbeddingStore<TextSegment> create(String collectionName);
|
||||||
}
|
}
|
||||||
@@ -133,8 +133,7 @@ public class TagMetaServiceImpl implements TagMetaService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<TagResp> getTags(TagFilter tagFilter) {
|
public List<TagResp> getTags(TagFilter tagFilter) {
|
||||||
List<TagResp> tagRespList = tagRepository.queryTagRespList(tagFilter);
|
return tagRepository.queryTagRespList(tagFilter);
|
||||||
return tagRespList;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
Reference in New Issue
Block a user