diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 28d231863..3e19db67d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -4,7 +4,6 @@ import com.tencent.supersonic.common.service.EmbeddingService; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.model.output.Response; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchResult; @@ -16,9 +15,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.TextSegmentConvert; import dev.langchain4j.store.embedding.filter.Filter; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; -import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.MapUtils; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -26,27 +25,31 @@ import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Objects; import java.util.stream.Collectors; @Service @Slf4j -@RequiredArgsConstructor public class EmbeddingServiceImpl implements EmbeddingService { - private static final Map> embeddingStoreMap = new ConcurrentHashMap<>(); - private final EmbeddingStoreFactory embeddingStoreFactory; - private final EmbeddingModel embeddingModel; + + @Autowired + private EmbeddingStoreFactory embeddingStoreFactory; + + @Autowired + private EmbeddingModel embeddingModel; @Override public void addQuery(String collectionName, List queries) { - EmbeddingStore embeddingStore = embeddingStoreMap - .computeIfAbsent(collectionName, embeddingStoreFactory::create); - try { - Response> embedAll = embeddingModel.embedAll(queries); - embeddingStore.addAll(embedAll.content(), queries); - } catch (Exception e) { - log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries, - embeddingStore.getClass().getSimpleName(), e); + EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); + for (TextSegment query : queries) { + String question = query.text(); + try { + Embedding embedding = embeddingModel.embed(question).content(); + embeddingStore.add(embedding, query); + } catch (Exception e) { + log.error("embeddingModel embed error question: {}, embeddingStore: {}", question, + embeddingStore.getClass().getSimpleName(), e); + } } } @@ -58,8 +61,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { public List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { List results = new ArrayList<>(); - EmbeddingStore embeddingStore = embeddingStoreMap - .computeIfAbsent(collectionName, embeddingStoreFactory::create); + EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); List queryTextsList = retrieveQuery.getQueryTextsList(); Map filterCondition = retrieveQuery.getFilterCondition(); for (String queryText : queryTextsList) { @@ -68,7 +70,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() .queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); - EmbeddingSearchResult result = embeddingStore.search(request); + EmbeddingSearchResult result = embeddingStore.search(request); List> relevant = result.matches(); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); @@ -81,7 +83,8 @@ public class EmbeddingServiceImpl implements EmbeddingService { retrieval.setId(TextSegmentConvert.getQueryId(embedded)); retrieval.setQuery(embedded.text()); Map metadata = new HashMap<>(); - if (MapUtils.isNotEmpty(embedded.metadata().toMap())) { + if (Objects.nonNull(embedded) + && MapUtils.isNotEmpty(embedded.metadata().toMap())) { metadata.putAll(embedded.metadata().toMap()); } retrieval.setMetadata(metadata); diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java index 8a2dd0f05..5e47576cc 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java @@ -1,15 +1,12 @@ package dev.langchain4j.chroma.spring; -import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; +import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore; import lombok.extern.slf4j.Slf4j; -import java.util.Objects; - @Slf4j -public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory { +public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { private Properties properties; @@ -18,31 +15,12 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory { } @Override - public EmbeddingStore create(String collectionName) { + public EmbeddingStore createEmbeddingStore(String collectionName) { EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore(); - EmbeddingStore embeddingStore = null; - try { - embeddingStore = ChromaEmbeddingStore.builder() - .baseUrl(storeProperties.getBaseUrl()) - .collectionName(collectionName) - .timeout(storeProperties.getTimeout()) - .build(); - } catch (Exception e) { - log.debug("Failed to create ChromaEmbeddingStore,collectionName:{}" - + ", fallback to the default InMemoryEmbeddingStore method。", - collectionName, e.getMessage()); - } - if (Objects.isNull(embeddingStore)) { - embeddingStore = createInMemoryEmbeddingStore(collectionName); - } - return embeddingStore; - } - - private EmbeddingStore createInMemoryEmbeddingStore(String collectionName) { - dev.langchain4j.inmemory.spring.Properties properties = new dev.langchain4j.inmemory.spring.Properties(); - dev.langchain4j.inmemory.spring.EmbeddingStoreProperties embeddingStoreProperties = - new dev.langchain4j.inmemory.spring.EmbeddingStoreProperties(); - properties.setEmbeddingStore(embeddingStoreProperties); - return new InMemoryEmbeddingStoreFactory(properties).create(collectionName); + return ChromaEmbeddingStore.builder() + .baseUrl(storeProperties.getBaseUrl()) + .collectionName(collectionName) + .timeout(storeProperties.getTimeout()) + .build(); } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java index 969347619..05bee26ba 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java @@ -3,8 +3,8 @@ package dev.langchain4j.inmemory.spring; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.MapUtils; @@ -15,15 +15,12 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Map; import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArraySet; @Slf4j -public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { +public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public static final String PERSISTENT_FILE_PRE = "InMemory."; - private static Map> collectionNameToStore = - new ConcurrentHashMap<>(); private Properties properties; @@ -32,18 +29,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { } @Override - public synchronized EmbeddingStore create(String collectionName) { - InMemoryEmbeddingStore embeddingStore = collectionNameToStore.get(collectionName); - if (Objects.nonNull(embeddingStore)) { - return embeddingStore; - } - embeddingStore = reloadFromPersistFile(collectionName); + public synchronized EmbeddingStore createEmbeddingStore(String collectionName) { + InMemoryEmbeddingStore embeddingStore = reloadFromPersistFile(collectionName); if (Objects.isNull(embeddingStore)) { embeddingStore = new InMemoryEmbeddingStore(); } - collectionNameToStore.putIfAbsent(collectionName, embeddingStore); return embeddingStore; - } private InMemoryEmbeddingStore reloadFromPersistFile(String collectionName) { @@ -67,10 +58,10 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { } public synchronized void persistFile() { - if (MapUtils.isEmpty(collectionNameToStore)) { + if (MapUtils.isEmpty(super.collectionNameToStore)) { return; } - for (Map.Entry> entry : collectionNameToStore.entrySet()) { + for (Map.Entry> entry : collectionNameToStore.entrySet()) { Path filePath = getPersistPath(entry.getKey()); if (Objects.isNull(filePath)) { continue; @@ -80,7 +71,11 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { if (!Files.exists(directoryPath)) { Files.createDirectories(directoryPath); } - entry.getValue().serializeToFile(filePath); + if (entry.getValue() instanceof InMemoryEmbeddingStore) { + InMemoryEmbeddingStore inMemoryEmbeddingStore = + (InMemoryEmbeddingStore) entry.getValue(); + inMemoryEmbeddingStore.serializeToFile(filePath); + } } catch (Exception e) { log.error("persistFile error, persistFile:" + filePath, e); } diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/EmbeddingStoreProperties.java b/common/src/main/java/dev/langchain4j/milvus/spring/EmbeddingStoreProperties.java index a9a595df9..0dab5f45a 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/EmbeddingStoreProperties.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/EmbeddingStoreProperties.java @@ -23,4 +23,5 @@ class EmbeddingStoreProperties { private ConsistencyLevelEnum consistencyLevel; private Boolean retrieveEmbeddingsOnSearch; private String databaseName; + private Boolean autoFlushOnInsert; } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java index 0ca780057..873c5f129 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java @@ -1,11 +1,11 @@ package dev.langchain4j.milvus.spring; import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; -public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory { +public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { private final Properties properties; public MilvusEmbeddingStoreFactory(Properties properties) { @@ -13,22 +13,23 @@ public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory { } @Override - public EmbeddingStore create(String collectionName) { - EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore(); + public EmbeddingStore createEmbeddingStore(String collectionName) { + EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore(); return MilvusEmbeddingStore.builder() - .host(embeddingStore.getHost()) - .port(embeddingStore.getPort()) + .host(storeProperties.getHost()) + .port(storeProperties.getPort()) .collectionName(collectionName) - .dimension(embeddingStore.getDimension()) - .indexType(embeddingStore.getIndexType()) - .metricType(embeddingStore.getMetricType()) - .uri(embeddingStore.getUri()) - .token(embeddingStore.getToken()) - .username(embeddingStore.getUsername()) - .password(embeddingStore.getPassword()) - .consistencyLevel(embeddingStore.getConsistencyLevel()) - .retrieveEmbeddingsOnSearch(embeddingStore.getRetrieveEmbeddingsOnSearch()) - .databaseName(embeddingStore.getDatabaseName()) + .dimension(storeProperties.getDimension()) + .indexType(storeProperties.getIndexType()) + .metricType(storeProperties.getMetricType()) + .uri(storeProperties.getUri()) + .token(storeProperties.getToken()) + .username(storeProperties.getUsername()) + .password(storeProperties.getPassword()) + .consistencyLevel(storeProperties.getConsistencyLevel()) + .retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch()) + .autoFlushOnInsert(storeProperties.getAutoFlushOnInsert()) + .databaseName(storeProperties.getDatabaseName()) .build(); } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/store/embedding/BaseEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/store/embedding/BaseEmbeddingStoreFactory.java new file mode 100644 index 000000000..2f895b1b9 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/store/embedding/BaseEmbeddingStoreFactory.java @@ -0,0 +1,16 @@ +package dev.langchain4j.store.embedding; + +import dev.langchain4j.data.segment.TextSegment; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory { + protected static final Map> collectionNameToStore = new ConcurrentHashMap<>(); + + public EmbeddingStore create(String collectionName) { + return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore); + } + + public abstract EmbeddingStore createEmbeddingStore(String collectionName); +} \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java index d818ea80d..c808bbe6c 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactory.java @@ -5,4 +5,5 @@ import dev.langchain4j.data.segment.TextSegment; public interface EmbeddingStoreFactory { EmbeddingStore create(String collectionName); + } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java b/common/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java new file mode 100644 index 000000000..04b6bbef5 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java @@ -0,0 +1,375 @@ +package dev.langchain4j.store.embedding.milvus; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.internal.Utils; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; +import io.milvus.client.MilvusServiceClient; +import io.milvus.common.clientenum.ConsistencyLevelEnum; +import io.milvus.param.ConnectParam; +import io.milvus.param.IndexType; +import io.milvus.param.MetricType; +import io.milvus.param.dml.InsertParam; +import io.milvus.param.dml.SearchParam; +import io.milvus.response.SearchResultsWrapper; + +import java.util.ArrayList; +import java.util.List; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createCollection; +import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createIndex; +import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.flush; +import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.hasCollection; +import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.insert; +import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.loadCollectionInMemory; +import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest; +import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds; +import static dev.langchain4j.store.embedding.milvus.Mapper.toEmbeddingMatches; +import static dev.langchain4j.store.embedding.milvus.Mapper.toMetadataJsons; +import static dev.langchain4j.store.embedding.milvus.Mapper.toScalars; +import static dev.langchain4j.store.embedding.milvus.Mapper.toVectors; +import static io.milvus.common.clientenum.ConsistencyLevelEnum.EVENTUALLY; +import static io.milvus.param.IndexType.FLAT; +import static io.milvus.param.MetricType.COSINE; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; + +/** + * Represents an Milvus index as an embedding store. + *
+ * Supports both local and managed Milvus instances. + *
+ * Supports storing {@link Metadata} and filtering by it using a {@link Filter} + * (provided inside an {@link EmbeddingSearchRequest}). + */ +public class MilvusEmbeddingStore implements EmbeddingStore { + + static final String ID_FIELD_NAME = "id"; + static final String TEXT_FIELD_NAME = "text"; + static final String METADATA_FIELD_NAME = "metadata"; + static final String VECTOR_FIELD_NAME = "vector"; + + private final MilvusServiceClient milvusClient; + private final String collectionName; + private final MetricType metricType; + private final ConsistencyLevelEnum consistencyLevel; + private final boolean retrieveEmbeddingsOnSearch; + + private final boolean autoFlushOnInsert; + + public MilvusEmbeddingStore( + String host, + Integer port, + String collectionName, + Integer dimension, + IndexType indexType, + MetricType metricType, + String uri, + String token, + String username, + String password, + ConsistencyLevelEnum consistencyLevel, + Boolean retrieveEmbeddingsOnSearch, + Boolean autoFlushOnInsert, + String databaseName + ) { + ConnectParam.Builder connectBuilder = ConnectParam + .newBuilder() + .withHost(getOrDefault(host, "localhost")) + .withPort(getOrDefault(port, 19530)) + .withUri(uri) + .withToken(token) + .withAuthorization(username, password); + + if (databaseName != null) { + connectBuilder.withDatabaseName(databaseName); + } + + this.milvusClient = new MilvusServiceClient(connectBuilder.build()); + this.collectionName = getOrDefault(collectionName, "default"); + this.metricType = getOrDefault(metricType, COSINE); + this.consistencyLevel = getOrDefault(consistencyLevel, EVENTUALLY); + this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false); + this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false); + if (!hasCollection(milvusClient, this.collectionName)) { + createCollection(milvusClient, this.collectionName, ensureNotNull(dimension, "dimension")); + createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType); + } + + loadCollectionInMemory(milvusClient, collectionName); + } + + public void dropCollection(String collectionName) { + CollectionOperationsExecutor.dropCollection(milvusClient, collectionName); + } + + public String add(Embedding embedding) { + String id = Utils.randomUUID(); + add(id, embedding); + return id; + } + + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + public String add(Embedding embedding, TextSegment textSegment) { + String id = Utils.randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + public List addAll(List embeddings) { + List ids = generateRandomIds(embeddings.size()); + addAllInternal(ids, embeddings, null); + return ids; + } + + public List addAll(List embeddings, List embedded) { + List ids = generateRandomIds(embeddings.size()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + @Override + public EmbeddingSearchResult search(EmbeddingSearchRequest embeddingSearchRequest) { + + SearchParam searchParam = buildSearchRequest( + collectionName, + embeddingSearchRequest.queryEmbedding().vectorAsList(), + embeddingSearchRequest.filter(), + embeddingSearchRequest.maxResults(), + metricType, + consistencyLevel + ); + + SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam); + + List> matches = toEmbeddingMatches( + milvusClient, + resultsWrapper, + collectionName, + consistencyLevel, + retrieveEmbeddingsOnSearch + ); + + List> result = matches.stream() + .filter(match -> match.score() >= embeddingSearchRequest.minScore()) + .collect(toList()); + + return new EmbeddingSearchResult<>(result); + } + + private void addInternal(String id, Embedding embedding, TextSegment textSegment) { + addAllInternal( + singletonList(id), + singletonList(embedding), + textSegment == null ? null : singletonList(textSegment) + ); + } + + private void addAllInternal(List ids, List embeddings, List textSegments) { + List fields = new ArrayList<>(); + fields.add(new InsertParam.Field(ID_FIELD_NAME, ids)); + fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size()))); + fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size()))); + fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings))); + + insert(milvusClient, collectionName, fields); + if (autoFlushOnInsert) { + flush(this.milvusClient, this.collectionName); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String host; + private Integer port; + private String collectionName; + private Integer dimension; + private IndexType indexType; + private MetricType metricType; + private String uri; + private String token; + private String username; + private String password; + private ConsistencyLevelEnum consistencyLevel; + private Boolean retrieveEmbeddingsOnSearch; + private Boolean autoFlushOnInsert; + private String databaseName; + + /** + * @param host The host of the self-managed Milvus instance. + * Default value: "localhost". + * @return builder + */ + public Builder host(String host) { + this.host = host; + return this; + } + + /** + * @param port The port of the self-managed Milvus instance. + * Default value: 19530. + * @return builder + */ + public Builder port(Integer port) { + this.port = port; + return this; + } + + /** + * @param collectionName The name of the Milvus collection. + * If there is no such collection yet, it will be created automatically. + * Default value: "default". + * @return builder + */ + public Builder collectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * @param dimension The dimension of the embedding vector. (e.g. 384) + * Mandatory if a new collection should be created. + * @return builder + */ + public Builder dimension(Integer dimension) { + this.dimension = dimension; + return this; + } + + /** + * @param indexType The type of the index. + * Default value: FLAT. + * @return builder + */ + public Builder indexType(IndexType indexType) { + this.indexType = indexType; + return this; + } + + /** + * @param metricType The type of the metric used for similarity search. + * Default value: COSINE. + * @return builder + */ + public Builder metricType(MetricType metricType) { + this.metricType = metricType; + return this; + } + + /** + * @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com") + * @return builder + */ + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + /** + * @param token The token (API key) of the managed Milvus instance. + * @return builder + */ + public Builder token(String token) { + this.token = token; + return this; + } + + /** + * @param username The username. See details here. + * @return builder + */ + public Builder username(String username) { + this.username = username; + return this; + } + + /** + * @param password The password. See details here. + * @return builder + */ + public Builder password(String password) { + this.password = password; + return this; + } + + /** + * @param consistencyLevel The consistency level used by Milvus. + * Default value: EVENTUALLY. + * @return builder + */ + public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) { + this.consistencyLevel = consistencyLevel; + return this; + } + + /** + * @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()), + * the embedding itself is not retrieved. + * To retrieve the embedding, an additional query is required. + * Setting this parameter to "true" will ensure that embedding is retrieved. + * Be aware that this will impact the performance of the search. + * Default value: false. + * @return builder + */ + public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) { + this.retrieveEmbeddingsOnSearch = retrieveEmbeddingsOnSearch; + return this; + } + + /** + * @param autoFlushOnInsert Whether to automatically flush after each insert + * ({@code add(...)} or {@code addAll(...)} methods). + * Default value: false. + * More info can be found + * here. + * @return builder + */ + public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) { + this.autoFlushOnInsert = autoFlushOnInsert; + return this; + } + + /** + * @param databaseName Milvus name of database. + * Default value: null. In this case default Milvus database name will be used. + * @return builder + */ + public Builder databaseName(String databaseName) { + this.databaseName = databaseName; + return this; + } + + public MilvusEmbeddingStore build() { + return new MilvusEmbeddingStore( + host, + port, + collectionName, + dimension, + indexType, + metricType, + uri, + token, + username, + password, + consistencyLevel, + retrieveEmbeddingsOnSearch, + autoFlushOnInsert, + databaseName + ); + } + } +} diff --git a/launchers/standalone/src/main/resources/langchain4j-local.yaml b/launchers/standalone/src/main/resources/langchain4j-local.yaml index f0573990f..e25dd9cb0 100644 --- a/launchers/standalone/src/main/resources/langchain4j-local.yaml +++ b/launchers/standalone/src/main/resources/langchain4j-local.yaml @@ -24,19 +24,19 @@ langchain4j: embedding-model: model-name: bge-small-zh - embedding-store: - persist-path: /tmp +# embedding-store: +# persist-path: /tmp # chroma: # embedding-store: # baseUrl: http://0.0.0.0:8000 # timeout: 120s -# milvus: -# embedding-store: -# host: localhost -# port: 2379 -# uri: http://0.0.0.0:19530 -# token: demo -# dimension: 512 -# timeout: 120s \ No newline at end of file + milvus: + embedding-store: + host: localhost + port: 2379 + uri: http://0.0.0.0:19530 + token: demo + dimension: 512 + timeout: 120s \ No newline at end of file diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml new file mode 100644 index 000000000..69890c191 --- /dev/null +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -0,0 +1,15 @@ +spring: + datasource: + driver-class-name: org.h2.Driver + schema: classpath:db/schema-h2.sql + data: classpath:db/data-h2.sql + url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false + username: root + password: semantic + h2: + console: + path: /h2-console/semantic + enabled: true + config: + import: + - classpath:langchain4j-local.yaml \ No newline at end of file diff --git a/launchers/standalone/src/test/resources/application.yaml b/launchers/standalone/src/test/resources/application.yaml index 84097eaaa..2aa5ac6f2 100644 --- a/launchers/standalone/src/test/resources/application.yaml +++ b/launchers/standalone/src/test/resources/application.yaml @@ -13,7 +13,6 @@ spring: config: import: - classpath:s2-config.yaml - - classpath:langchain4j-local.yaml autoconfigure: exclude: - spring.dev.langchain4j.spring.LangChain4jAutoConfig @@ -22,17 +21,6 @@ spring: - spring.dev.langchain4j.azure.openai.spring.AutoConfig - spring.dev.langchain4j.azure.aisearch.spring.AutoConfig - spring.dev.langchain4j.anthropic.spring.AutoConfig - h2: - console: - path: /h2-console/semantic - enabled: true - datasource: - driver-class-name: org.h2.Driver - schema: classpath:db/schema-h2.sql - data: classpath:db/data-h2.sql - url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false - username: root - password: semantic mybatis: mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml @@ -40,4 +28,16 @@ mybatis: logging: level: dev.langchain4j: DEBUG - dev.ai4j.openai4j: DEBUG \ No newline at end of file + dev.ai4j.openai4j: DEBUG + +swagger: + title: 'SuperSonic API Documentation' + base: + package: com.tencent.supersonic + description: 'SuperSonic API Documentation' + url: '' + contact: + name: + email: + url: '' + version: 3.0 \ No newline at end of file