mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(improvement)(chat) Fix the error in Milvus query and add the option to create EmbeddingStore based on caching mode (#1310)
This commit is contained in:
@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.service.EmbeddingService;
|
|||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||||
@@ -16,9 +15,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
|||||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import dev.langchain4j.store.embedding.filter.Filter;
|
import dev.langchain4j.store.embedding.filter.Filter;
|
||||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -26,27 +25,31 @@ import java.util.Comparator;
|
|||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||||
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
|
|
||||||
private final EmbeddingStoreFactory embeddingStoreFactory;
|
@Autowired
|
||||||
private final EmbeddingModel embeddingModel;
|
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private EmbeddingModel embeddingModel;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||||
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||||
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
for (TextSegment query : queries) {
|
||||||
try {
|
String question = query.text();
|
||||||
Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
|
try {
|
||||||
embeddingStore.addAll(embedAll.content(), queries);
|
Embedding embedding = embeddingModel.embed(question).content();
|
||||||
} catch (Exception e) {
|
embeddingStore.add(embedding, query);
|
||||||
log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
|
} catch (Exception e) {
|
||||||
embeddingStore.getClass().getSimpleName(), e);
|
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
|
||||||
|
embeddingStore.getClass().getSimpleName(), e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,8 +61,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||||
|
|
||||||
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||||
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
|
||||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||||
for (String queryText : queryTextsList) {
|
for (String queryText : queryTextsList) {
|
||||||
@@ -68,7 +70,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
||||||
|
|
||||||
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
|
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||||
|
|
||||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||||
@@ -81,7 +83,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||||
retrieval.setQuery(embedded.text());
|
retrieval.setQuery(embedded.text());
|
||||||
Map<String, Object> metadata = new HashMap<>();
|
Map<String, Object> metadata = new HashMap<>();
|
||||||
if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
if (Objects.nonNull(embedded)
|
||||||
|
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||||
metadata.putAll(embedded.metadata().toMap());
|
metadata.putAll(embedded.metadata().toMap());
|
||||||
}
|
}
|
||||||
retrieval.setMetadata(metadata);
|
retrieval.setMetadata(metadata);
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
package dev.langchain4j.chroma.spring;
|
package dev.langchain4j.chroma.spring;
|
||||||
|
|
||||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
|
||||||
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
|
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||||
|
|
||||||
private Properties properties;
|
private Properties properties;
|
||||||
|
|
||||||
@@ -18,31 +15,12 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingStore create(String collectionName) {
|
public EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||||
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
||||||
EmbeddingStore embeddingStore = null;
|
return ChromaEmbeddingStore.builder()
|
||||||
try {
|
.baseUrl(storeProperties.getBaseUrl())
|
||||||
embeddingStore = ChromaEmbeddingStore.builder()
|
.collectionName(collectionName)
|
||||||
.baseUrl(storeProperties.getBaseUrl())
|
.timeout(storeProperties.getTimeout())
|
||||||
.collectionName(collectionName)
|
.build();
|
||||||
.timeout(storeProperties.getTimeout())
|
|
||||||
.build();
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.debug("Failed to create ChromaEmbeddingStore,collectionName:{}"
|
|
||||||
+ ", fallback to the default InMemoryEmbeddingStore method。",
|
|
||||||
collectionName, e.getMessage());
|
|
||||||
}
|
|
||||||
if (Objects.isNull(embeddingStore)) {
|
|
||||||
embeddingStore = createInMemoryEmbeddingStore(collectionName);
|
|
||||||
}
|
|
||||||
return embeddingStore;
|
|
||||||
}
|
|
||||||
|
|
||||||
private EmbeddingStore createInMemoryEmbeddingStore(String collectionName) {
|
|
||||||
dev.langchain4j.inmemory.spring.Properties properties = new dev.langchain4j.inmemory.spring.Properties();
|
|
||||||
dev.langchain4j.inmemory.spring.EmbeddingStoreProperties embeddingStoreProperties =
|
|
||||||
new dev.langchain4j.inmemory.spring.EmbeddingStoreProperties();
|
|
||||||
properties.setEmbeddingStore(embeddingStoreProperties);
|
|
||||||
return new InMemoryEmbeddingStoreFactory(properties).create(collectionName);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3,8 +3,8 @@ package dev.langchain4j.inmemory.spring;
|
|||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
|
||||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.MapUtils;
|
import org.apache.commons.collections4.MapUtils;
|
||||||
@@ -15,15 +15,12 @@ import java.nio.file.Path;
|
|||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.concurrent.CopyOnWriteArraySet;
|
import java.util.concurrent.CopyOnWriteArraySet;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||||
|
|
||||||
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
||||||
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
|
|
||||||
new ConcurrentHashMap<>();
|
|
||||||
private Properties properties;
|
private Properties properties;
|
||||||
|
|
||||||
|
|
||||||
@@ -32,18 +29,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public synchronized EmbeddingStore create(String collectionName) {
|
public synchronized EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||||
InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
|
InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName);
|
||||||
if (Objects.nonNull(embeddingStore)) {
|
|
||||||
return embeddingStore;
|
|
||||||
}
|
|
||||||
embeddingStore = reloadFromPersistFile(collectionName);
|
|
||||||
if (Objects.isNull(embeddingStore)) {
|
if (Objects.isNull(embeddingStore)) {
|
||||||
embeddingStore = new InMemoryEmbeddingStore();
|
embeddingStore = new InMemoryEmbeddingStore();
|
||||||
}
|
}
|
||||||
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
|
||||||
return embeddingStore;
|
return embeddingStore;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
|
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
|
||||||
@@ -67,10 +58,10 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public synchronized void persistFile() {
|
public synchronized void persistFile() {
|
||||||
if (MapUtils.isEmpty(collectionNameToStore)) {
|
if (MapUtils.isEmpty(super.collectionNameToStore)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (Map.Entry<String, InMemoryEmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
|
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
|
||||||
Path filePath = getPersistPath(entry.getKey());
|
Path filePath = getPersistPath(entry.getKey());
|
||||||
if (Objects.isNull(filePath)) {
|
if (Objects.isNull(filePath)) {
|
||||||
continue;
|
continue;
|
||||||
@@ -80,7 +71,11 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
|||||||
if (!Files.exists(directoryPath)) {
|
if (!Files.exists(directoryPath)) {
|
||||||
Files.createDirectories(directoryPath);
|
Files.createDirectories(directoryPath);
|
||||||
}
|
}
|
||||||
entry.getValue().serializeToFile(filePath);
|
if (entry.getValue() instanceof InMemoryEmbeddingStore) {
|
||||||
|
InMemoryEmbeddingStore<TextSegment> inMemoryEmbeddingStore =
|
||||||
|
(InMemoryEmbeddingStore) entry.getValue();
|
||||||
|
inMemoryEmbeddingStore.serializeToFile(filePath);
|
||||||
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("persistFile error, persistFile:" + filePath, e);
|
log.error("persistFile error, persistFile:" + filePath, e);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,4 +23,5 @@ class EmbeddingStoreProperties {
|
|||||||
private ConsistencyLevelEnum consistencyLevel;
|
private ConsistencyLevelEnum consistencyLevel;
|
||||||
private Boolean retrieveEmbeddingsOnSearch;
|
private Boolean retrieveEmbeddingsOnSearch;
|
||||||
private String databaseName;
|
private String databaseName;
|
||||||
|
private Boolean autoFlushOnInsert;
|
||||||
}
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
package dev.langchain4j.milvus.spring;
|
package dev.langchain4j.milvus.spring;
|
||||||
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
|
||||||
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
||||||
|
|
||||||
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||||
private final Properties properties;
|
private final Properties properties;
|
||||||
|
|
||||||
public MilvusEmbeddingStoreFactory(Properties properties) {
|
public MilvusEmbeddingStoreFactory(Properties properties) {
|
||||||
@@ -13,22 +13,23 @@ public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingStore<TextSegment> create(String collectionName) {
|
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||||
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
||||||
return MilvusEmbeddingStore.builder()
|
return MilvusEmbeddingStore.builder()
|
||||||
.host(embeddingStore.getHost())
|
.host(storeProperties.getHost())
|
||||||
.port(embeddingStore.getPort())
|
.port(storeProperties.getPort())
|
||||||
.collectionName(collectionName)
|
.collectionName(collectionName)
|
||||||
.dimension(embeddingStore.getDimension())
|
.dimension(storeProperties.getDimension())
|
||||||
.indexType(embeddingStore.getIndexType())
|
.indexType(storeProperties.getIndexType())
|
||||||
.metricType(embeddingStore.getMetricType())
|
.metricType(storeProperties.getMetricType())
|
||||||
.uri(embeddingStore.getUri())
|
.uri(storeProperties.getUri())
|
||||||
.token(embeddingStore.getToken())
|
.token(storeProperties.getToken())
|
||||||
.username(embeddingStore.getUsername())
|
.username(storeProperties.getUsername())
|
||||||
.password(embeddingStore.getPassword())
|
.password(storeProperties.getPassword())
|
||||||
.consistencyLevel(embeddingStore.getConsistencyLevel())
|
.consistencyLevel(storeProperties.getConsistencyLevel())
|
||||||
.retrieveEmbeddingsOnSearch(embeddingStore.getRetrieveEmbeddingsOnSearch())
|
.retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch())
|
||||||
.databaseName(embeddingStore.getDatabaseName())
|
.autoFlushOnInsert(storeProperties.getAutoFlushOnInsert())
|
||||||
|
.databaseName(storeProperties.getDatabaseName())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||||
|
protected static final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
public EmbeddingStore<TextSegment> create(String collectionName) {
|
||||||
|
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName);
|
||||||
|
}
|
||||||
@@ -5,4 +5,5 @@ import dev.langchain4j.data.segment.TextSegment;
|
|||||||
public interface EmbeddingStoreFactory {
|
public interface EmbeddingStoreFactory {
|
||||||
|
|
||||||
EmbeddingStore<TextSegment> create(String collectionName);
|
EmbeddingStore<TextSegment> create(String collectionName);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,375 @@
|
|||||||
|
package dev.langchain4j.store.embedding.milvus;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.internal.Utils;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
import dev.langchain4j.store.embedding.filter.Filter;
|
||||||
|
import io.milvus.client.MilvusServiceClient;
|
||||||
|
import io.milvus.common.clientenum.ConsistencyLevelEnum;
|
||||||
|
import io.milvus.param.ConnectParam;
|
||||||
|
import io.milvus.param.IndexType;
|
||||||
|
import io.milvus.param.MetricType;
|
||||||
|
import io.milvus.param.dml.InsertParam;
|
||||||
|
import io.milvus.param.dml.SearchParam;
|
||||||
|
import io.milvus.response.SearchResultsWrapper;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||||
|
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createCollection;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createIndex;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.flush;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.hasCollection;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.insert;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.loadCollectionInMemory;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.Mapper.toEmbeddingMatches;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.Mapper.toMetadataJsons;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.Mapper.toScalars;
|
||||||
|
import static dev.langchain4j.store.embedding.milvus.Mapper.toVectors;
|
||||||
|
import static io.milvus.common.clientenum.ConsistencyLevelEnum.EVENTUALLY;
|
||||||
|
import static io.milvus.param.IndexType.FLAT;
|
||||||
|
import static io.milvus.param.MetricType.COSINE;
|
||||||
|
import static java.util.Collections.singletonList;
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents an <a href="https://milvus.io/">Milvus</a> index as an embedding store.
|
||||||
|
* <br>
|
||||||
|
* Supports both local and <a href="https://zilliz.com/">managed</a> Milvus instances.
|
||||||
|
* <br>
|
||||||
|
* Supports storing {@link Metadata} and filtering by it using a {@link Filter}
|
||||||
|
* (provided inside an {@link EmbeddingSearchRequest}).
|
||||||
|
*/
|
||||||
|
public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||||
|
|
||||||
|
static final String ID_FIELD_NAME = "id";
|
||||||
|
static final String TEXT_FIELD_NAME = "text";
|
||||||
|
static final String METADATA_FIELD_NAME = "metadata";
|
||||||
|
static final String VECTOR_FIELD_NAME = "vector";
|
||||||
|
|
||||||
|
private final MilvusServiceClient milvusClient;
|
||||||
|
private final String collectionName;
|
||||||
|
private final MetricType metricType;
|
||||||
|
private final ConsistencyLevelEnum consistencyLevel;
|
||||||
|
private final boolean retrieveEmbeddingsOnSearch;
|
||||||
|
|
||||||
|
private final boolean autoFlushOnInsert;
|
||||||
|
|
||||||
|
/**
 * Connects to Milvus, makes sure the target collection exists, and loads it.
 *
 * <p>Null arguments fall back to these defaults: host {@code "localhost"},
 * port {@code 19530}, collection name {@code "default"}, metric type
 * {@code COSINE}, consistency level {@code EVENTUALLY}, index type
 * {@code FLAT}, and {@code false} for both boolean flags.
 *
 * @param host                       host of the Milvus instance
 * @param port                       port of the Milvus instance
 * @param collectionName             collection to use; created (with an index) when it does not exist yet
 * @param dimension                  embedding dimension; must be non-null when the collection has to be created
 * @param indexType                  index type used when the collection is created
 * @param metricType                 metric used for index creation and for searches
 * @param uri                        URI passed to the connection builder
 * @param token                      token passed to the connection builder
 * @param username                   username passed to the connection builder
 * @param password                   password passed to the connection builder
 * @param consistencyLevel           consistency level used for searches
 * @param retrieveEmbeddingsOnSearch whether search results should carry the stored embeddings
 * @param autoFlushOnInsert          whether every insert is followed by a flush
 * @param databaseName               database to connect to; only applied when non-null
 */
public MilvusEmbeddingStore(
        String host,
        Integer port,
        String collectionName,
        Integer dimension,
        IndexType indexType,
        MetricType metricType,
        String uri,
        String token,
        String username,
        String password,
        ConsistencyLevelEnum consistencyLevel,
        Boolean retrieveEmbeddingsOnSearch,
        Boolean autoFlushOnInsert,
        String databaseName
) {
    ConnectParam.Builder connectBuilder = ConnectParam
            .newBuilder()
            .withHost(getOrDefault(host, "localhost"))
            .withPort(getOrDefault(port, 19530))
            .withUri(uri)
            .withToken(token)
            .withAuthorization(username, password);

    // The database name is optional; leave the builder untouched when none is given.
    if (databaseName != null) {
        connectBuilder.withDatabaseName(databaseName);
    }

    this.milvusClient = new MilvusServiceClient(connectBuilder.build());
    this.collectionName = getOrDefault(collectionName, "default");
    this.metricType = getOrDefault(metricType, COSINE);
    this.consistencyLevel = getOrDefault(consistencyLevel, EVENTUALLY);
    this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
    this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);

    // Create the collection and its index on first use; dimension is only required on this path.
    if (!hasCollection(this.milvusClient, this.collectionName)) {
        createCollection(this.milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
        createIndex(this.milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType);
    }

    // Load the resolved name (the field), not the raw parameter: when the caller passes
    // null the collection was created as "default" and that is the one to load.
    loadCollectionInMemory(this.milvusClient, this.collectionName);
}
|
||||||
|
|
||||||
|
public void dropCollection(String collectionName) {
|
||||||
|
CollectionOperationsExecutor.dropCollection(milvusClient, collectionName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String add(Embedding embedding) {
|
||||||
|
String id = Utils.randomUUID();
|
||||||
|
add(id, embedding);
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void add(String id, Embedding embedding) {
|
||||||
|
addInternal(id, embedding, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String add(Embedding embedding, TextSegment textSegment) {
|
||||||
|
String id = Utils.randomUUID();
|
||||||
|
addInternal(id, embedding, textSegment);
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> addAll(List<Embedding> embeddings) {
|
||||||
|
List<String> ids = generateRandomIds(embeddings.size());
|
||||||
|
addAllInternal(ids, embeddings, null);
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||||
|
List<String> ids = generateRandomIds(embeddings.size());
|
||||||
|
addAllInternal(ids, embeddings, embedded);
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
|
||||||
|
|
||||||
|
SearchParam searchParam = buildSearchRequest(
|
||||||
|
collectionName,
|
||||||
|
embeddingSearchRequest.queryEmbedding().vectorAsList(),
|
||||||
|
embeddingSearchRequest.filter(),
|
||||||
|
embeddingSearchRequest.maxResults(),
|
||||||
|
metricType,
|
||||||
|
consistencyLevel
|
||||||
|
);
|
||||||
|
|
||||||
|
SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam);
|
||||||
|
|
||||||
|
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(
|
||||||
|
milvusClient,
|
||||||
|
resultsWrapper,
|
||||||
|
collectionName,
|
||||||
|
consistencyLevel,
|
||||||
|
retrieveEmbeddingsOnSearch
|
||||||
|
);
|
||||||
|
|
||||||
|
List<EmbeddingMatch<TextSegment>> result = matches.stream()
|
||||||
|
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
|
||||||
|
.collect(toList());
|
||||||
|
|
||||||
|
return new EmbeddingSearchResult<>(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
|
||||||
|
addAllInternal(
|
||||||
|
singletonList(id),
|
||||||
|
singletonList(embedding),
|
||||||
|
textSegment == null ? null : singletonList(textSegment)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
|
||||||
|
List<InsertParam.Field> fields = new ArrayList<>();
|
||||||
|
fields.add(new InsertParam.Field(ID_FIELD_NAME, ids));
|
||||||
|
fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size())));
|
||||||
|
fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
|
||||||
|
fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));
|
||||||
|
|
||||||
|
insert(milvusClient, collectionName, fields);
|
||||||
|
if (autoFlushOnInsert) {
|
||||||
|
flush(this.milvusClient, this.collectionName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
private String host;
|
||||||
|
private Integer port;
|
||||||
|
private String collectionName;
|
||||||
|
private Integer dimension;
|
||||||
|
private IndexType indexType;
|
||||||
|
private MetricType metricType;
|
||||||
|
private String uri;
|
||||||
|
private String token;
|
||||||
|
private String username;
|
||||||
|
private String password;
|
||||||
|
private ConsistencyLevelEnum consistencyLevel;
|
||||||
|
private Boolean retrieveEmbeddingsOnSearch;
|
||||||
|
private Boolean autoFlushOnInsert;
|
||||||
|
private String databaseName;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param host The host of the self-managed Milvus instance.
|
||||||
|
* Default value: "localhost".
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder host(String host) {
|
||||||
|
this.host = host;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param port The port of the self-managed Milvus instance.
|
||||||
|
* Default value: 19530.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder port(Integer port) {
|
||||||
|
this.port = port;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param collectionName The name of the Milvus collection.
|
||||||
|
* If there is no such collection yet, it will be created automatically.
|
||||||
|
* Default value: "default".
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder collectionName(String collectionName) {
|
||||||
|
this.collectionName = collectionName;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param dimension The dimension of the embedding vector. (e.g. 384)
|
||||||
|
* Mandatory if a new collection should be created.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder dimension(Integer dimension) {
|
||||||
|
this.dimension = dimension;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param indexType The type of the index.
|
||||||
|
* Default value: FLAT.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder indexType(IndexType indexType) {
|
||||||
|
this.indexType = indexType;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param metricType The type of the metric used for similarity search.
|
||||||
|
* Default value: COSINE.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder metricType(MetricType metricType) {
|
||||||
|
this.metricType = metricType;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com")
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder uri(String uri) {
|
||||||
|
this.uri = uri;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param token The token (API key) of the managed Milvus instance.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder token(String token) {
|
||||||
|
this.token = token;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param username The username. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder username(String username) {
|
||||||
|
this.username = username;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param password The password. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder password(String password) {
|
||||||
|
this.password = password;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param consistencyLevel The consistency level used by Milvus.
|
||||||
|
* Default value: EVENTUALLY.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) {
|
||||||
|
this.consistencyLevel = consistencyLevel;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()),
|
||||||
|
* the embedding itself is not retrieved.
|
||||||
|
* To retrieve the embedding, an additional query is required.
|
||||||
|
* Setting this parameter to "true" ensures that the embedding is retrieved.
|
||||||
|
* Be aware that this will impact the performance of the search.
|
||||||
|
* Default value: false.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) {
|
||||||
|
this.retrieveEmbeddingsOnSearch = retrieveEmbeddingsOnSearch;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param autoFlushOnInsert Whether to automatically flush after each insert
|
||||||
|
* ({@code add(...)} or {@code addAll(...)} methods).
|
||||||
|
* Default value: false.
|
||||||
|
* More info can be found
|
||||||
|
* <a href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
|
||||||
|
this.autoFlushOnInsert = autoFlushOnInsert;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param databaseName Name of the Milvus database.
|
||||||
|
* Default value: null, in which case the default Milvus database name is used.
|
||||||
|
* @return builder
|
||||||
|
*/
|
||||||
|
public Builder databaseName(String databaseName) {
|
||||||
|
this.databaseName = databaseName;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MilvusEmbeddingStore build() {
|
||||||
|
return new MilvusEmbeddingStore(
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
collectionName,
|
||||||
|
dimension,
|
||||||
|
indexType,
|
||||||
|
metricType,
|
||||||
|
uri,
|
||||||
|
token,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
consistencyLevel,
|
||||||
|
retrieveEmbeddingsOnSearch,
|
||||||
|
autoFlushOnInsert,
|
||||||
|
databaseName
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,19 +24,19 @@ langchain4j:
|
|||||||
embedding-model:
|
embedding-model:
|
||||||
model-name: bge-small-zh
|
model-name: bge-small-zh
|
||||||
|
|
||||||
embedding-store:
|
# embedding-store:
|
||||||
persist-path: /tmp
|
# persist-path: /tmp
|
||||||
|
|
||||||
# chroma:
|
# chroma:
|
||||||
# embedding-store:
|
# embedding-store:
|
||||||
# baseUrl: http://0.0.0.0:8000
|
# baseUrl: http://0.0.0.0:8000
|
||||||
# timeout: 120s
|
# timeout: 120s
|
||||||
|
|
||||||
# milvus:
|
milvus:
|
||||||
# embedding-store:
|
embedding-store:
|
||||||
# host: localhost
|
host: localhost
|
||||||
# port: 2379
|
port: 2379
|
||||||
# uri: http://0.0.0.0:19530
|
uri: http://0.0.0.0:19530
|
||||||
# token: demo
|
token: demo
|
||||||
# dimension: 512
|
dimension: 512
|
||||||
# timeout: 120s
|
timeout: 120s
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
spring:
|
||||||
|
datasource:
|
||||||
|
driver-class-name: org.h2.Driver
|
||||||
|
schema: classpath:db/schema-h2.sql
|
||||||
|
data: classpath:db/data-h2.sql
|
||||||
|
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false
|
||||||
|
username: root
|
||||||
|
password: semantic
|
||||||
|
h2:
|
||||||
|
console:
|
||||||
|
path: /h2-console/semantic
|
||||||
|
enabled: true
|
||||||
|
config:
|
||||||
|
import:
|
||||||
|
- classpath:langchain4j-local.yaml
|
||||||
@@ -13,7 +13,6 @@ spring:
|
|||||||
config:
|
config:
|
||||||
import:
|
import:
|
||||||
- classpath:s2-config.yaml
|
- classpath:s2-config.yaml
|
||||||
- classpath:langchain4j-local.yaml
|
|
||||||
autoconfigure:
|
autoconfigure:
|
||||||
exclude:
|
exclude:
|
||||||
- spring.dev.langchain4j.spring.LangChain4jAutoConfig
|
- spring.dev.langchain4j.spring.LangChain4jAutoConfig
|
||||||
@@ -22,17 +21,6 @@ spring:
|
|||||||
- spring.dev.langchain4j.azure.openai.spring.AutoConfig
|
- spring.dev.langchain4j.azure.openai.spring.AutoConfig
|
||||||
- spring.dev.langchain4j.azure.aisearch.spring.AutoConfig
|
- spring.dev.langchain4j.azure.aisearch.spring.AutoConfig
|
||||||
- spring.dev.langchain4j.anthropic.spring.AutoConfig
|
- spring.dev.langchain4j.anthropic.spring.AutoConfig
|
||||||
h2:
|
|
||||||
console:
|
|
||||||
path: /h2-console/semantic
|
|
||||||
enabled: true
|
|
||||||
datasource:
|
|
||||||
driver-class-name: org.h2.Driver
|
|
||||||
schema: classpath:db/schema-h2.sql
|
|
||||||
data: classpath:db/data-h2.sql
|
|
||||||
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false
|
|
||||||
username: root
|
|
||||||
password: semantic
|
|
||||||
|
|
||||||
mybatis:
|
mybatis:
|
||||||
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
|
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
|
||||||
@@ -40,4 +28,16 @@ mybatis:
|
|||||||
logging:
|
logging:
|
||||||
level:
|
level:
|
||||||
dev.langchain4j: DEBUG
|
dev.langchain4j: DEBUG
|
||||||
dev.ai4j.openai4j: DEBUG
|
dev.ai4j.openai4j: DEBUG
|
||||||
|
|
||||||
|
swagger:
|
||||||
|
title: 'SuperSonic API Documentation'
|
||||||
|
base:
|
||||||
|
package: com.tencent.supersonic
|
||||||
|
description: 'SuperSonic API Documentation'
|
||||||
|
url: ''
|
||||||
|
contact:
|
||||||
|
name:
|
||||||
|
email:
|
||||||
|
url: ''
|
||||||
|
version: 3.0
|
||||||
Reference in New Issue
Block a user