(improvement)(chat) Fix the error in Milvus query and add the option to create EmbeddingStore based on caching mode (#1310)

This commit is contained in:
lexluo09
2024-07-01 16:29:43 +08:00
committed by GitHub
parent 37d08007c4
commit 7773442fbf
11 changed files with 489 additions and 104 deletions

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
@@ -16,9 +15,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
@@ -26,27 +25,31 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Objects;
import java.util.stream.Collectors;
@Service
@Slf4j
@RequiredArgsConstructor
public class EmbeddingServiceImpl implements EmbeddingService {
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
private final EmbeddingStoreFactory embeddingStoreFactory;
private final EmbeddingModel embeddingModel;
@Autowired
private EmbeddingStoreFactory embeddingStoreFactory;
@Autowired
private EmbeddingModel embeddingModel;
@Override
public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
try {
Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
embeddingStore.addAll(embedAll.content(), queries);
} catch (Exception e) {
log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
embeddingStore.getClass().getSimpleName(), e);
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
for (TextSegment query : queries) {
String question = query.text();
try {
Embedding embedding = embeddingModel.embed(question).content();
embeddingStore.add(embedding, query);
} catch (Exception e) {
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
embeddingStore.getClass().getSimpleName(), e);
}
}
}
@@ -58,8 +61,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
List<RetrieveQueryResult> results = new ArrayList<>();
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) {
@@ -68,7 +70,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
EmbeddingSearchResult result = embeddingStore.search(request);
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
@@ -81,7 +83,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>();
if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
if (Objects.nonNull(embedded)
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.metadata().toMap());
}
retrieval.setMetadata(metadata);