mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) Fix the error in Milvus query and add the option to create EmbeddingStore based on caching mode (#1310)
This commit is contained in:
@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
@@ -16,9 +15,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -26,27 +25,31 @@ import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
|
||||
private final EmbeddingStoreFactory embeddingStoreFactory;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingModel embeddingModel;
|
||||
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
||||
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
||||
try {
|
||||
Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
|
||||
embeddingStore.addAll(embedAll.content(), queries);
|
||||
} catch (Exception e) {
|
||||
log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
|
||||
embeddingStore.getClass().getSimpleName(), e);
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
for (TextSegment query : queries) {
|
||||
String question = query.text();
|
||||
try {
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
embeddingStore.add(embedding, query);
|
||||
} catch (Exception e) {
|
||||
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
|
||||
embeddingStore.getClass().getSimpleName(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,8 +61,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
|
||||
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
|
||||
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
for (String queryText : queryTextsList) {
|
||||
@@ -68,7 +70,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
||||
|
||||
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
|
||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||
|
||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||
@@ -81,7 +83,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||
retrieval.setQuery(embedded.text());
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||
if (Objects.nonNull(embedded)
|
||||
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||
metadata.putAll(embedded.metadata().toMap());
|
||||
}
|
||||
retrieval.setMetadata(metadata);
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package dev.langchain4j.chroma.spring;
|
||||
|
||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
|
||||
private Properties properties;
|
||||
|
||||
@@ -18,31 +15,12 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingStore create(String collectionName) {
|
||||
public EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
||||
EmbeddingStore embeddingStore = null;
|
||||
try {
|
||||
embeddingStore = ChromaEmbeddingStore.builder()
|
||||
.baseUrl(storeProperties.getBaseUrl())
|
||||
.collectionName(collectionName)
|
||||
.timeout(storeProperties.getTimeout())
|
||||
.build();
|
||||
} catch (Exception e) {
|
||||
log.debug("Failed to create ChromaEmbeddingStore,collectionName:{}"
|
||||
+ ", fallback to the default InMemoryEmbeddingStore method。",
|
||||
collectionName, e.getMessage());
|
||||
}
|
||||
if (Objects.isNull(embeddingStore)) {
|
||||
embeddingStore = createInMemoryEmbeddingStore(collectionName);
|
||||
}
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
private EmbeddingStore createInMemoryEmbeddingStore(String collectionName) {
|
||||
dev.langchain4j.inmemory.spring.Properties properties = new dev.langchain4j.inmemory.spring.Properties();
|
||||
dev.langchain4j.inmemory.spring.EmbeddingStoreProperties embeddingStoreProperties =
|
||||
new dev.langchain4j.inmemory.spring.EmbeddingStoreProperties();
|
||||
properties.setEmbeddingStore(embeddingStoreProperties);
|
||||
return new InMemoryEmbeddingStoreFactory(properties).create(collectionName);
|
||||
return ChromaEmbeddingStore.builder()
|
||||
.baseUrl(storeProperties.getBaseUrl())
|
||||
.collectionName(collectionName)
|
||||
.timeout(storeProperties.getTimeout())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,8 @@ package dev.langchain4j.inmemory.spring;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
@@ -15,15 +15,12 @@ import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArraySet;
|
||||
|
||||
@Slf4j
|
||||
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
|
||||
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
||||
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
|
||||
new ConcurrentHashMap<>();
|
||||
private Properties properties;
|
||||
|
||||
|
||||
@@ -32,18 +29,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized EmbeddingStore create(String collectionName) {
|
||||
InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
|
||||
if (Objects.nonNull(embeddingStore)) {
|
||||
return embeddingStore;
|
||||
}
|
||||
embeddingStore = reloadFromPersistFile(collectionName);
|
||||
public synchronized EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||
InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName);
|
||||
if (Objects.isNull(embeddingStore)) {
|
||||
embeddingStore = new InMemoryEmbeddingStore();
|
||||
}
|
||||
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
||||
return embeddingStore;
|
||||
|
||||
}
|
||||
|
||||
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
|
||||
@@ -67,10 +58,10 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
}
|
||||
|
||||
public synchronized void persistFile() {
|
||||
if (MapUtils.isEmpty(collectionNameToStore)) {
|
||||
if (MapUtils.isEmpty(super.collectionNameToStore)) {
|
||||
return;
|
||||
}
|
||||
for (Map.Entry<String, InMemoryEmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
|
||||
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
|
||||
Path filePath = getPersistPath(entry.getKey());
|
||||
if (Objects.isNull(filePath)) {
|
||||
continue;
|
||||
@@ -80,7 +71,11 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
if (!Files.exists(directoryPath)) {
|
||||
Files.createDirectories(directoryPath);
|
||||
}
|
||||
entry.getValue().serializeToFile(filePath);
|
||||
if (entry.getValue() instanceof InMemoryEmbeddingStore) {
|
||||
InMemoryEmbeddingStore<TextSegment> inMemoryEmbeddingStore =
|
||||
(InMemoryEmbeddingStore) entry.getValue();
|
||||
inMemoryEmbeddingStore.serializeToFile(filePath);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("persistFile error, persistFile:" + filePath, e);
|
||||
}
|
||||
|
||||
@@ -23,4 +23,5 @@ class EmbeddingStoreProperties {
|
||||
private ConsistencyLevelEnum consistencyLevel;
|
||||
private Boolean retrieveEmbeddingsOnSearch;
|
||||
private String databaseName;
|
||||
private Boolean autoFlushOnInsert;
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
package dev.langchain4j.milvus.spring;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
||||
|
||||
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
private final Properties properties;
|
||||
|
||||
public MilvusEmbeddingStoreFactory(Properties properties) {
|
||||
@@ -13,22 +13,23 @@ public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingStore<TextSegment> create(String collectionName) {
|
||||
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
||||
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
||||
return MilvusEmbeddingStore.builder()
|
||||
.host(embeddingStore.getHost())
|
||||
.port(embeddingStore.getPort())
|
||||
.host(storeProperties.getHost())
|
||||
.port(storeProperties.getPort())
|
||||
.collectionName(collectionName)
|
||||
.dimension(embeddingStore.getDimension())
|
||||
.indexType(embeddingStore.getIndexType())
|
||||
.metricType(embeddingStore.getMetricType())
|
||||
.uri(embeddingStore.getUri())
|
||||
.token(embeddingStore.getToken())
|
||||
.username(embeddingStore.getUsername())
|
||||
.password(embeddingStore.getPassword())
|
||||
.consistencyLevel(embeddingStore.getConsistencyLevel())
|
||||
.retrieveEmbeddingsOnSearch(embeddingStore.getRetrieveEmbeddingsOnSearch())
|
||||
.databaseName(embeddingStore.getDatabaseName())
|
||||
.dimension(storeProperties.getDimension())
|
||||
.indexType(storeProperties.getIndexType())
|
||||
.metricType(storeProperties.getMetricType())
|
||||
.uri(storeProperties.getUri())
|
||||
.token(storeProperties.getToken())
|
||||
.username(storeProperties.getUsername())
|
||||
.password(storeProperties.getPassword())
|
||||
.consistencyLevel(storeProperties.getConsistencyLevel())
|
||||
.retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch())
|
||||
.autoFlushOnInsert(storeProperties.getAutoFlushOnInsert())
|
||||
.databaseName(storeProperties.getDatabaseName())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
protected static final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
|
||||
|
||||
public EmbeddingStore<TextSegment> create(String collectionName) {
|
||||
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);
|
||||
}
|
||||
|
||||
public abstract EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName);
|
||||
}
|
||||
@@ -5,4 +5,5 @@ import dev.langchain4j.data.segment.TextSegment;
|
||||
public interface EmbeddingStoreFactory {
|
||||
|
||||
EmbeddingStore<TextSegment> create(String collectionName);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
package dev.langchain4j.store.embedding.milvus;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import io.milvus.common.clientenum.ConsistencyLevelEnum;
|
||||
import io.milvus.param.ConnectParam;
|
||||
import io.milvus.param.IndexType;
|
||||
import io.milvus.param.MetricType;
|
||||
import io.milvus.param.dml.InsertParam;
|
||||
import io.milvus.param.dml.SearchParam;
|
||||
import io.milvus.response.SearchResultsWrapper;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createCollection;
|
||||
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createIndex;
|
||||
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.flush;
|
||||
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.hasCollection;
|
||||
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.insert;
|
||||
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.loadCollectionInMemory;
|
||||
import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest;
|
||||
import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds;
|
||||
import static dev.langchain4j.store.embedding.milvus.Mapper.toEmbeddingMatches;
|
||||
import static dev.langchain4j.store.embedding.milvus.Mapper.toMetadataJsons;
|
||||
import static dev.langchain4j.store.embedding.milvus.Mapper.toScalars;
|
||||
import static dev.langchain4j.store.embedding.milvus.Mapper.toVectors;
|
||||
import static io.milvus.common.clientenum.ConsistencyLevelEnum.EVENTUALLY;
|
||||
import static io.milvus.param.IndexType.FLAT;
|
||||
import static io.milvus.param.MetricType.COSINE;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
/**
|
||||
* Represents an <a href="https://milvus.io/">Milvus</a> index as an embedding store.
|
||||
* <br>
|
||||
* Supports both local and <a href="https://zilliz.com/">managed</a> Milvus instances.
|
||||
* <br>
|
||||
* Supports storing {@link Metadata} and filtering by it using a {@link Filter}
|
||||
* (provided inside an {@link EmbeddingSearchRequest}).
|
||||
*/
|
||||
public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
static final String ID_FIELD_NAME = "id";
|
||||
static final String TEXT_FIELD_NAME = "text";
|
||||
static final String METADATA_FIELD_NAME = "metadata";
|
||||
static final String VECTOR_FIELD_NAME = "vector";
|
||||
|
||||
private final MilvusServiceClient milvusClient;
|
||||
private final String collectionName;
|
||||
private final MetricType metricType;
|
||||
private final ConsistencyLevelEnum consistencyLevel;
|
||||
private final boolean retrieveEmbeddingsOnSearch;
|
||||
|
||||
private final boolean autoFlushOnInsert;
|
||||
|
||||
public MilvusEmbeddingStore(
|
||||
String host,
|
||||
Integer port,
|
||||
String collectionName,
|
||||
Integer dimension,
|
||||
IndexType indexType,
|
||||
MetricType metricType,
|
||||
String uri,
|
||||
String token,
|
||||
String username,
|
||||
String password,
|
||||
ConsistencyLevelEnum consistencyLevel,
|
||||
Boolean retrieveEmbeddingsOnSearch,
|
||||
Boolean autoFlushOnInsert,
|
||||
String databaseName
|
||||
) {
|
||||
ConnectParam.Builder connectBuilder = ConnectParam
|
||||
.newBuilder()
|
||||
.withHost(getOrDefault(host, "localhost"))
|
||||
.withPort(getOrDefault(port, 19530))
|
||||
.withUri(uri)
|
||||
.withToken(token)
|
||||
.withAuthorization(username, password);
|
||||
|
||||
if (databaseName != null) {
|
||||
connectBuilder.withDatabaseName(databaseName);
|
||||
}
|
||||
|
||||
this.milvusClient = new MilvusServiceClient(connectBuilder.build());
|
||||
this.collectionName = getOrDefault(collectionName, "default");
|
||||
this.metricType = getOrDefault(metricType, COSINE);
|
||||
this.consistencyLevel = getOrDefault(consistencyLevel, EVENTUALLY);
|
||||
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
|
||||
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
|
||||
if (!hasCollection(milvusClient, this.collectionName)) {
|
||||
createCollection(milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
|
||||
createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType);
|
||||
}
|
||||
|
||||
loadCollectionInMemory(milvusClient, collectionName);
|
||||
}
|
||||
|
||||
public void dropCollection(String collectionName) {
|
||||
CollectionOperationsExecutor.dropCollection(milvusClient, collectionName);
|
||||
}
|
||||
|
||||
public String add(Embedding embedding) {
|
||||
String id = Utils.randomUUID();
|
||||
add(id, embedding);
|
||||
return id;
|
||||
}
|
||||
|
||||
public void add(String id, Embedding embedding) {
|
||||
addInternal(id, embedding, null);
|
||||
}
|
||||
|
||||
public String add(Embedding embedding, TextSegment textSegment) {
|
||||
String id = Utils.randomUUID();
|
||||
addInternal(id, embedding, textSegment);
|
||||
return id;
|
||||
}
|
||||
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
List<String> ids = generateRandomIds(embeddings.size());
|
||||
addAllInternal(ids, embeddings, null);
|
||||
return ids;
|
||||
}
|
||||
|
||||
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
||||
List<String> ids = generateRandomIds(embeddings.size());
|
||||
addAllInternal(ids, embeddings, embedded);
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
|
||||
|
||||
SearchParam searchParam = buildSearchRequest(
|
||||
collectionName,
|
||||
embeddingSearchRequest.queryEmbedding().vectorAsList(),
|
||||
embeddingSearchRequest.filter(),
|
||||
embeddingSearchRequest.maxResults(),
|
||||
metricType,
|
||||
consistencyLevel
|
||||
);
|
||||
|
||||
SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(
|
||||
milvusClient,
|
||||
resultsWrapper,
|
||||
collectionName,
|
||||
consistencyLevel,
|
||||
retrieveEmbeddingsOnSearch
|
||||
);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> result = matches.stream()
|
||||
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
|
||||
.collect(toList());
|
||||
|
||||
return new EmbeddingSearchResult<>(result);
|
||||
}
|
||||
|
||||
private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
|
||||
addAllInternal(
|
||||
singletonList(id),
|
||||
singletonList(embedding),
|
||||
textSegment == null ? null : singletonList(textSegment)
|
||||
);
|
||||
}
|
||||
|
||||
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
|
||||
List<InsertParam.Field> fields = new ArrayList<>();
|
||||
fields.add(new InsertParam.Field(ID_FIELD_NAME, ids));
|
||||
fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size())));
|
||||
fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
|
||||
fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));
|
||||
|
||||
insert(milvusClient, collectionName, fields);
|
||||
if (autoFlushOnInsert) {
|
||||
flush(this.milvusClient, this.collectionName);
|
||||
}
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String host;
|
||||
private Integer port;
|
||||
private String collectionName;
|
||||
private Integer dimension;
|
||||
private IndexType indexType;
|
||||
private MetricType metricType;
|
||||
private String uri;
|
||||
private String token;
|
||||
private String username;
|
||||
private String password;
|
||||
private ConsistencyLevelEnum consistencyLevel;
|
||||
private Boolean retrieveEmbeddingsOnSearch;
|
||||
private Boolean autoFlushOnInsert;
|
||||
private String databaseName;
|
||||
|
||||
/**
|
||||
* @param host The host of the self-managed Milvus instance.
|
||||
* Default value: "localhost".
|
||||
* @return builder
|
||||
*/
|
||||
public Builder host(String host) {
|
||||
this.host = host;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param port The port of the self-managed Milvus instance.
|
||||
* Default value: 19530.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder port(Integer port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param collectionName The name of the Milvus collection.
|
||||
* If there is no such collection yet, it will be created automatically.
|
||||
* Default value: "default".
|
||||
* @return builder
|
||||
*/
|
||||
public Builder collectionName(String collectionName) {
|
||||
this.collectionName = collectionName;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param dimension The dimension of the embedding vector. (e.g. 384)
|
||||
* Mandatory if a new collection should be created.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder dimension(Integer dimension) {
|
||||
this.dimension = dimension;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param indexType The type of the index.
|
||||
* Default value: FLAT.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder indexType(IndexType indexType) {
|
||||
this.indexType = indexType;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param metricType The type of the metric used for similarity search.
|
||||
* Default value: COSINE.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder metricType(MetricType metricType) {
|
||||
this.metricType = metricType;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com")
|
||||
* @return builder
|
||||
*/
|
||||
public Builder uri(String uri) {
|
||||
this.uri = uri;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param token The token (API key) of the managed Milvus instance.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder token(String token) {
|
||||
this.token = token;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param username The username. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder username(String username) {
|
||||
this.username = username;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param password The password. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder password(String password) {
|
||||
this.password = password;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param consistencyLevel The consistency level used by Milvus.
|
||||
* Default value: EVENTUALLY.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) {
|
||||
this.consistencyLevel = consistencyLevel;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()),
|
||||
* the embedding itself is not retrieved.
|
||||
* To retrieve the embedding, an additional query is required.
|
||||
* Setting this parameter to "true" will ensure that embedding is retrieved.
|
||||
* Be aware that this will impact the performance of the search.
|
||||
* Default value: false.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) {
|
||||
this.retrieveEmbeddingsOnSearch = retrieveEmbeddingsOnSearch;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param autoFlushOnInsert Whether to automatically flush after each insert
|
||||
* ({@code add(...)} or {@code addAll(...)} methods).
|
||||
* Default value: false.
|
||||
* More info can be found
|
||||
* <a href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
|
||||
this.autoFlushOnInsert = autoFlushOnInsert;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param databaseName Milvus name of database.
|
||||
* Default value: null. In this case default Milvus database name will be used.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder databaseName(String databaseName) {
|
||||
this.databaseName = databaseName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public MilvusEmbeddingStore build() {
|
||||
return new MilvusEmbeddingStore(
|
||||
host,
|
||||
port,
|
||||
collectionName,
|
||||
dimension,
|
||||
indexType,
|
||||
metricType,
|
||||
uri,
|
||||
token,
|
||||
username,
|
||||
password,
|
||||
consistencyLevel,
|
||||
retrieveEmbeddingsOnSearch,
|
||||
autoFlushOnInsert,
|
||||
databaseName
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,19 +24,19 @@ langchain4j:
|
||||
embedding-model:
|
||||
model-name: bge-small-zh
|
||||
|
||||
embedding-store:
|
||||
persist-path: /tmp
|
||||
# embedding-store:
|
||||
# persist-path: /tmp
|
||||
|
||||
# chroma:
|
||||
# embedding-store:
|
||||
# baseUrl: http://0.0.0.0:8000
|
||||
# timeout: 120s
|
||||
|
||||
# milvus:
|
||||
# embedding-store:
|
||||
# host: localhost
|
||||
# port: 2379
|
||||
# uri: http://0.0.0.0:19530
|
||||
# token: demo
|
||||
# dimension: 512
|
||||
# timeout: 120s
|
||||
milvus:
|
||||
embedding-store:
|
||||
host: localhost
|
||||
port: 2379
|
||||
uri: http://0.0.0.0:19530
|
||||
token: demo
|
||||
dimension: 512
|
||||
timeout: 120s
|
||||
@@ -0,0 +1,15 @@
|
||||
spring:
|
||||
datasource:
|
||||
driver-class-name: org.h2.Driver
|
||||
schema: classpath:db/schema-h2.sql
|
||||
data: classpath:db/data-h2.sql
|
||||
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false
|
||||
username: root
|
||||
password: semantic
|
||||
h2:
|
||||
console:
|
||||
path: /h2-console/semantic
|
||||
enabled: true
|
||||
config:
|
||||
import:
|
||||
- classpath:langchain4j-local.yaml
|
||||
@@ -13,7 +13,6 @@ spring:
|
||||
config:
|
||||
import:
|
||||
- classpath:s2-config.yaml
|
||||
- classpath:langchain4j-local.yaml
|
||||
autoconfigure:
|
||||
exclude:
|
||||
- spring.dev.langchain4j.spring.LangChain4jAutoConfig
|
||||
@@ -22,17 +21,6 @@ spring:
|
||||
- spring.dev.langchain4j.azure.openai.spring.AutoConfig
|
||||
- spring.dev.langchain4j.azure.aisearch.spring.AutoConfig
|
||||
- spring.dev.langchain4j.anthropic.spring.AutoConfig
|
||||
h2:
|
||||
console:
|
||||
path: /h2-console/semantic
|
||||
enabled: true
|
||||
datasource:
|
||||
driver-class-name: org.h2.Driver
|
||||
schema: classpath:db/schema-h2.sql
|
||||
data: classpath:db/data-h2.sql
|
||||
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false
|
||||
username: root
|
||||
password: semantic
|
||||
|
||||
mybatis:
|
||||
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
|
||||
@@ -40,4 +28,16 @@ mybatis:
|
||||
logging:
|
||||
level:
|
||||
dev.langchain4j: DEBUG
|
||||
dev.ai4j.openai4j: DEBUG
|
||||
dev.ai4j.openai4j: DEBUG
|
||||
|
||||
swagger:
|
||||
title: 'SuperSonic API Documentation'
|
||||
base:
|
||||
package: com.tencent.supersonic
|
||||
description: 'SuperSonic API Documentation'
|
||||
url: ''
|
||||
contact:
|
||||
name:
|
||||
email:
|
||||
url: ''
|
||||
version: 3.0
|
||||
Reference in New Issue
Block a user