(improvement)(chat) Fix the error in Milvus query and add the option to create EmbeddingStore based on caching mode (#1310)

This commit is contained in:
lexluo09
2024-07-01 16:29:43 +08:00
committed by GitHub
parent 37d08007c4
commit 7773442fbf
11 changed files with 489 additions and 104 deletions

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
@@ -16,9 +15,9 @@ import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
@@ -26,27 +25,31 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Objects;
import java.util.stream.Collectors;
@Service
@Slf4j
@RequiredArgsConstructor
public class EmbeddingServiceImpl implements EmbeddingService {
private static final Map<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
private final EmbeddingStoreFactory embeddingStoreFactory;
private final EmbeddingModel embeddingModel;
@Autowired
private EmbeddingStoreFactory embeddingStoreFactory;
@Autowired
private EmbeddingModel embeddingModel;
@Override
public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
try {
Response<List<Embedding>> embedAll = embeddingModel.embedAll(queries);
embeddingStore.addAll(embedAll.content(), queries);
} catch (Exception e) {
log.error("embeddingModel embed error queries: {}, embeddingStore: {}", queries,
embeddingStore.getClass().getSimpleName(), e);
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
for (TextSegment query : queries) {
String question = query.text();
try {
Embedding embedding = embeddingModel.embed(question).content();
embeddingStore.add(embedding, query);
} catch (Exception e) {
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
embeddingStore.getClass().getSimpleName(), e);
}
}
}
@@ -58,8 +61,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
List<RetrieveQueryResult> results = new ArrayList<>();
EmbeddingStore<TextSegment> embeddingStore = embeddingStoreMap
.computeIfAbsent(collectionName, embeddingStoreFactory::create);
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) {
@@ -68,7 +70,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
EmbeddingSearchResult result = embeddingStore.search(request);
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
@@ -81,7 +83,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>();
if (MapUtils.isNotEmpty(embedded.metadata().toMap())) {
if (Objects.nonNull(embedded)
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.metadata().toMap());
}
retrieval.setMetadata(metadata);

View File

@@ -1,15 +1,12 @@
package dev.langchain4j.chroma.spring;
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
@Slf4j
public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private Properties properties;
@@ -18,31 +15,12 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
}
@Override
public EmbeddingStore create(String collectionName) {
public EmbeddingStore createEmbeddingStore(String collectionName) {
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
EmbeddingStore embeddingStore = null;
try {
embeddingStore = ChromaEmbeddingStore.builder()
.baseUrl(storeProperties.getBaseUrl())
.collectionName(collectionName)
.timeout(storeProperties.getTimeout())
.build();
} catch (Exception e) {
log.debug("Failed to create ChromaEmbeddingStore,collectionName:{}"
+ ", fallback to the default InMemoryEmbeddingStore method。",
collectionName, e.getMessage());
}
if (Objects.isNull(embeddingStore)) {
embeddingStore = createInMemoryEmbeddingStore(collectionName);
}
return embeddingStore;
}
private EmbeddingStore createInMemoryEmbeddingStore(String collectionName) {
dev.langchain4j.inmemory.spring.Properties properties = new dev.langchain4j.inmemory.spring.Properties();
dev.langchain4j.inmemory.spring.EmbeddingStoreProperties embeddingStoreProperties =
new dev.langchain4j.inmemory.spring.EmbeddingStoreProperties();
properties.setEmbeddingStore(embeddingStoreProperties);
return new InMemoryEmbeddingStoreFactory(properties).create(collectionName);
return ChromaEmbeddingStore.builder()
.baseUrl(storeProperties.getBaseUrl())
.collectionName(collectionName)
.timeout(storeProperties.getTimeout())
.build();
}
}

View File

@@ -3,8 +3,8 @@ package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
@@ -15,15 +15,12 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
@Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
private Properties properties;
@@ -32,18 +29,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
}
@Override
public synchronized EmbeddingStore create(String collectionName) {
InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
if (Objects.nonNull(embeddingStore)) {
return embeddingStore;
}
embeddingStore = reloadFromPersistFile(collectionName);
public synchronized EmbeddingStore createEmbeddingStore(String collectionName) {
InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName);
if (Objects.isNull(embeddingStore)) {
embeddingStore = new InMemoryEmbeddingStore();
}
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
return embeddingStore;
}
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
@@ -67,10 +58,10 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
}
public synchronized void persistFile() {
if (MapUtils.isEmpty(collectionNameToStore)) {
if (MapUtils.isEmpty(super.collectionNameToStore)) {
return;
}
for (Map.Entry<String, InMemoryEmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
Path filePath = getPersistPath(entry.getKey());
if (Objects.isNull(filePath)) {
continue;
@@ -80,7 +71,11 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
if (!Files.exists(directoryPath)) {
Files.createDirectories(directoryPath);
}
entry.getValue().serializeToFile(filePath);
if (entry.getValue() instanceof InMemoryEmbeddingStore) {
InMemoryEmbeddingStore<TextSegment> inMemoryEmbeddingStore =
(InMemoryEmbeddingStore) entry.getValue();
inMemoryEmbeddingStore.serializeToFile(filePath);
}
} catch (Exception e) {
log.error("persistFile error, persistFile:" + filePath, e);
}

View File

@@ -23,4 +23,5 @@ class EmbeddingStoreProperties {
private ConsistencyLevelEnum consistencyLevel;
private Boolean retrieveEmbeddingsOnSearch;
private String databaseName;
private Boolean autoFlushOnInsert;
}

View File

@@ -1,11 +1,11 @@
package dev.langchain4j.milvus.spring;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private final Properties properties;
public MilvusEmbeddingStoreFactory(Properties properties) {
@@ -13,22 +13,23 @@ public class MilvusEmbeddingStoreFactory implements EmbeddingStoreFactory {
}
@Override
public EmbeddingStore<TextSegment> create(String collectionName) {
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
return MilvusEmbeddingStore.builder()
.host(embeddingStore.getHost())
.port(embeddingStore.getPort())
.host(storeProperties.getHost())
.port(storeProperties.getPort())
.collectionName(collectionName)
.dimension(embeddingStore.getDimension())
.indexType(embeddingStore.getIndexType())
.metricType(embeddingStore.getMetricType())
.uri(embeddingStore.getUri())
.token(embeddingStore.getToken())
.username(embeddingStore.getUsername())
.password(embeddingStore.getPassword())
.consistencyLevel(embeddingStore.getConsistencyLevel())
.retrieveEmbeddingsOnSearch(embeddingStore.getRetrieveEmbeddingsOnSearch())
.databaseName(embeddingStore.getDatabaseName())
.dimension(storeProperties.getDimension())
.indexType(storeProperties.getIndexType())
.metricType(storeProperties.getMetricType())
.uri(storeProperties.getUri())
.token(storeProperties.getToken())
.username(storeProperties.getUsername())
.password(storeProperties.getPassword())
.consistencyLevel(storeProperties.getConsistencyLevel())
.retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch())
.autoFlushOnInsert(storeProperties.getAutoFlushOnInsert())
.databaseName(storeProperties.getDatabaseName())
.build();
}
}

View File

@@ -0,0 +1,16 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.data.segment.TextSegment;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Skeleton {@link EmbeddingStoreFactory} that remembers the store built for each collection
 * name, so that repeated {@link #create(String)} calls for the same name hand back the same
 * {@link EmbeddingStore} instance.
 *
 * <p>Subclasses only supply {@link #createEmbeddingStore(String)}, which builds the concrete
 * store the first time a collection name is requested.
 */
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {

    /**
     * Stores built so far, keyed by collection name. The map is {@code static}, so it is
     * shared by every factory instance (and every subclass) loaded in the same class loader.
     */
    protected static final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();

    /**
     * Builds a new store for the given collection. Invoked by {@link #create(String)} only
     * when no store is registered yet under that name.
     *
     * @param collectionName name of the collection the store is for
     * @return the newly built store
     */
    public abstract EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName);

    /**
     * Returns the store registered for {@code collectionName}, building and registering it
     * through {@link #createEmbeddingStore(String)} on first use.
     *
     * @param collectionName name of the collection the store is for
     * @return the cached or newly built store
     */
    public EmbeddingStore<TextSegment> create(String collectionName) {
        return collectionNameToStore.computeIfAbsent(
                collectionName,
                missingName -> createEmbeddingStore(missingName));
    }
}

View File

@@ -5,4 +5,5 @@ import dev.langchain4j.data.segment.TextSegment;
/**
 * Factory that supplies the {@link EmbeddingStore} used for a named collection.
 */
public interface EmbeddingStoreFactory {
/**
 * Returns the embedding store of {@link TextSegment}s for the given collection.
 *
 * @param collectionName name of the collection the store is for
 * @return the embedding store for {@code collectionName}
 */
EmbeddingStore<TextSegment> create(String collectionName);
}

View File

@@ -0,0 +1,375 @@
package dev.langchain4j.store.embedding.milvus;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.param.ConnectParam;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.response.SearchResultsWrapper;
import java.util.ArrayList;
import java.util.List;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createCollection;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.createIndex;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.flush;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.hasCollection;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.insert;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.loadCollectionInMemory;
import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest;
import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds;
import static dev.langchain4j.store.embedding.milvus.Mapper.toEmbeddingMatches;
import static dev.langchain4j.store.embedding.milvus.Mapper.toMetadataJsons;
import static dev.langchain4j.store.embedding.milvus.Mapper.toScalars;
import static dev.langchain4j.store.embedding.milvus.Mapper.toVectors;
import static io.milvus.common.clientenum.ConsistencyLevelEnum.EVENTUALLY;
import static io.milvus.param.IndexType.FLAT;
import static io.milvus.param.MetricType.COSINE;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
/**
 * Represents a <a href="https://milvus.io/">Milvus</a> index as an embedding store.
 * <br>
 * Supports both local and <a href="https://zilliz.com/">managed</a> Milvus instances.
 * <br>
 * Supports storing {@link Metadata} and filtering by it using a {@link Filter}
 * (provided inside an {@link EmbeddingSearchRequest}).
 */
public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {

    static final String ID_FIELD_NAME = "id";
    static final String TEXT_FIELD_NAME = "text";
    static final String METADATA_FIELD_NAME = "metadata";
    static final String VECTOR_FIELD_NAME = "vector";

    private final MilvusServiceClient milvusClient;
    private final String collectionName;
    private final MetricType metricType;
    private final ConsistencyLevelEnum consistencyLevel;
    private final boolean retrieveEmbeddingsOnSearch;
    private final boolean autoFlushOnInsert;

    /**
     * Creates the Milvus client and makes sure the target collection is ready for use:
     * when the collection does not exist yet it is created together with its index, and in
     * every case the collection is then loaded into memory.
     *
     * @param host                       host of the self-managed Milvus instance; defaults to "localhost"
     * @param port                       port of the self-managed Milvus instance; defaults to 19530
     * @param collectionName             name of the collection; defaults to "default"
     * @param dimension                  dimension of the embedding vector; must be non-null when the
     *                                   collection has to be created
     * @param indexType                  index type used when the index is created; defaults to FLAT
     * @param metricType                 metric used for similarity search; defaults to COSINE
     * @param uri                        URI of the managed Milvus instance
     * @param token                      token (API key) of the managed Milvus instance
     * @param username                   user name used for authorization
     * @param password                   password used for authorization
     * @param consistencyLevel           consistency level used for reads; defaults to EVENTUALLY
     * @param retrieveEmbeddingsOnSearch whether matches should carry their embedding; defaults to false
     * @param autoFlushOnInsert          whether to flush after each insert; defaults to false
     * @param databaseName               Milvus database name; when {@code null} it is not set on the
     *                                   connection
     */
    public MilvusEmbeddingStore(
            String host,
            Integer port,
            String collectionName,
            Integer dimension,
            IndexType indexType,
            MetricType metricType,
            String uri,
            String token,
            String username,
            String password,
            ConsistencyLevelEnum consistencyLevel,
            Boolean retrieveEmbeddingsOnSearch,
            Boolean autoFlushOnInsert,
            String databaseName
    ) {
        ConnectParam.Builder connectBuilder = ConnectParam
                .newBuilder()
                .withHost(getOrDefault(host, "localhost"))
                .withPort(getOrDefault(port, 19530))
                .withUri(uri)
                .withToken(token)
                .withAuthorization(username, password);

        if (databaseName != null) {
            connectBuilder.withDatabaseName(databaseName);
        }

        this.milvusClient = new MilvusServiceClient(connectBuilder.build());
        this.collectionName = getOrDefault(collectionName, "default");
        this.metricType = getOrDefault(metricType, COSINE);
        this.consistencyLevel = getOrDefault(consistencyLevel, EVENTUALLY);
        this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
        this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);

        if (!hasCollection(milvusClient, this.collectionName)) {
            createCollection(milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
            createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType);
        }

        // Use the resolved field, not the raw constructor argument: the argument may be null,
        // in which case the collection that was checked/created above is "default".
        loadCollectionInMemory(milvusClient, this.collectionName);
    }

    /**
     * Drops the collection with the given name using this store's client.
     * The name is taken from the argument, not from the collection this store was built for.
     *
     * @param collectionName name of the collection to drop
     */
    public void dropCollection(String collectionName) {
        CollectionOperationsExecutor.dropCollection(milvusClient, collectionName);
    }

    /**
     * Stores an embedding without text under a freshly generated id.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /**
     * Stores an embedding without text under the given id.
     *
     * @param id        id to store the embedding under
     * @param embedding the embedding to store
     */
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /**
     * Stores an embedding together with the text segment it was computed from.
     *
     * @param embedding   the embedding to store
     * @param textSegment the segment associated with the embedding; may be {@code null}
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores several embeddings without text, one generated id per embedding.
     *
     * @param embeddings the embeddings to store
     * @return the generated ids
     */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = generateRandomIds(embeddings.size());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores several embeddings together with their text segments, one generated id per embedding.
     *
     * @param embeddings the embeddings to store
     * @param embedded   the segments associated with the embeddings
     * @return the generated ids
     */
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = generateRandomIds(embeddings.size());
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Runs a similarity search for the request's query embedding, applying the request's
     * filter and result limit, and keeps only the matches whose score reaches the
     * request's minimum score.
     *
     * @param embeddingSearchRequest the search request
     * @return the matches that passed the minimum-score check
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        SearchParam searchParam = buildSearchRequest(
                collectionName,
                embeddingSearchRequest.queryEmbedding().vectorAsList(),
                embeddingSearchRequest.filter(),
                embeddingSearchRequest.maxResults(),
                metricType,
                consistencyLevel
        );

        SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam);

        List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(
                milvusClient,
                resultsWrapper,
                collectionName,
                consistencyLevel,
                retrieveEmbeddingsOnSearch
        );

        // The minimum score is applied here, on the returned matches, not in the search request.
        List<EmbeddingMatch<TextSegment>> result = matches.stream()
                .filter(match -> match.score() >= embeddingSearchRequest.minScore())
                .collect(toList());

        return new EmbeddingSearchResult<>(result);
    }

    /** Wraps a single entry into singleton lists and delegates to the batch insert. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(
                singletonList(id),
                singletonList(embedding),
                textSegment == null ? null : singletonList(textSegment)
        );
    }

    /**
     * Inserts one row per id (id, text, metadata JSON, vector) and flushes afterwards when
     * auto-flush is enabled. An empty batch is a no-op.
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
        if (ids.isEmpty()) {
            // Nothing to insert: skip the insert round-trip and the flush.
            return;
        }

        List<InsertParam.Field> fields = new ArrayList<>();
        fields.add(new InsertParam.Field(ID_FIELD_NAME, ids));
        fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size())));
        fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
        fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));

        insert(milvusClient, collectionName, fields);
        if (autoFlushOnInsert) {
            flush(this.milvusClient, this.collectionName);
        }
    }

    /**
     * @return a new builder with every option unset
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link MilvusEmbeddingStore}. Options left unset are passed to the
     * constructor as {@code null}, which applies the defaults documented on each setter.
     */
    public static class Builder {

        private String host;
        private Integer port;
        private String collectionName;
        private Integer dimension;
        private IndexType indexType;
        private MetricType metricType;
        private String uri;
        private String token;
        private String username;
        private String password;
        private ConsistencyLevelEnum consistencyLevel;
        private Boolean retrieveEmbeddingsOnSearch;
        private Boolean autoFlushOnInsert;
        private String databaseName;

        /**
         * @param host The host of the self-managed Milvus instance.
         *             Default value: "localhost".
         * @return builder
         */
        public Builder host(String host) {
            this.host = host;
            return this;
        }

        /**
         * @param port The port of the self-managed Milvus instance.
         *             Default value: 19530.
         * @return builder
         */
        public Builder port(Integer port) {
            this.port = port;
            return this;
        }

        /**
         * @param collectionName The name of the Milvus collection.
         *                       If there is no such collection yet, it will be created automatically.
         *                       Default value: "default".
         * @return builder
         */
        public Builder collectionName(String collectionName) {
            this.collectionName = collectionName;
            return this;
        }

        /**
         * @param dimension The dimension of the embedding vector. (e.g. 384)
         *                  Mandatory if a new collection should be created.
         * @return builder
         */
        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /**
         * @param indexType The type of the index.
         *                  Default value: FLAT.
         * @return builder
         */
        public Builder indexType(IndexType indexType) {
            this.indexType = indexType;
            return this;
        }

        /**
         * @param metricType The type of the metric used for similarity search.
         *                   Default value: COSINE.
         * @return builder
         */
        public Builder metricType(MetricType metricType) {
            this.metricType = metricType;
            return this;
        }

        /**
         * @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com")
         * @return builder
         */
        public Builder uri(String uri) {
            this.uri = uri;
            return this;
        }

        /**
         * @param token The token (API key) of the managed Milvus instance.
         * @return builder
         */
        public Builder token(String token) {
            this.token = token;
            return this;
        }

        /**
         * @param username The username. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
         * @return builder
         */
        public Builder username(String username) {
            this.username = username;
            return this;
        }

        /**
         * @param password The password. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
         * @return builder
         */
        public Builder password(String password) {
            this.password = password;
            return this;
        }

        /**
         * @param consistencyLevel The consistency level used by Milvus.
         *                         Default value: EVENTUALLY.
         * @return builder
         */
        public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) {
            this.consistencyLevel = consistencyLevel;
            return this;
        }

        /**
         * @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()),
         *                                   the embedding itself is not retrieved.
         *                                   To retrieve the embedding, an additional query is required.
         *                                   Setting this parameter to "true" will ensure that embedding is retrieved.
         *                                   Be aware that this will impact the performance of the search.
         *                                   Default value: false.
         * @return builder
         */
        public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) {
            this.retrieveEmbeddingsOnSearch = retrieveEmbeddingsOnSearch;
            return this;
        }

        /**
         * @param autoFlushOnInsert Whether to automatically flush after each insert
         *                          ({@code add(...)} or {@code addAll(...)} methods).
         *                          Default value: false.
         *                          More info can be found
         *                          <a href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
         * @return builder
         */
        public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
            this.autoFlushOnInsert = autoFlushOnInsert;
            return this;
        }

        /**
         * @param databaseName Milvus name of database.
         *                     Default value: null. In this case default Milvus database name will be used.
         * @return builder
         */
        public Builder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        /**
         * Builds the store from the options collected so far.
         *
         * @return a new {@link MilvusEmbeddingStore}
         */
        public MilvusEmbeddingStore build() {
            return new MilvusEmbeddingStore(
                    host,
                    port,
                    collectionName,
                    dimension,
                    indexType,
                    metricType,
                    uri,
                    token,
                    username,
                    password,
                    consistencyLevel,
                    retrieveEmbeddingsOnSearch,
                    autoFlushOnInsert,
                    databaseName
            );
        }
    }
}

View File

@@ -24,19 +24,19 @@ langchain4j:
embedding-model:
model-name: bge-small-zh
embedding-store:
persist-path: /tmp
# embedding-store:
# persist-path: /tmp
# chroma:
# embedding-store:
# baseUrl: http://0.0.0.0:8000
# timeout: 120s
# milvus:
# embedding-store:
# host: localhost
# port: 2379
# uri: http://0.0.0.0:19530
# token: demo
# dimension: 512
# timeout: 120s
milvus:
embedding-store:
host: localhost
port: 2379
uri: http://0.0.0.0:19530
token: demo
dimension: 512
timeout: 120s

View File

@@ -0,0 +1,15 @@
spring:
datasource:
driver-class-name: org.h2.Driver
schema: classpath:db/schema-h2.sql
data: classpath:db/data-h2.sql
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false
username: root
password: semantic
h2:
console:
path: /h2-console/semantic
enabled: true
config:
import:
- classpath:langchain4j-local.yaml

View File

@@ -13,7 +13,6 @@ spring:
config:
import:
- classpath:s2-config.yaml
- classpath:langchain4j-local.yaml
autoconfigure:
exclude:
- spring.dev.langchain4j.spring.LangChain4jAutoConfig
@@ -22,17 +21,6 @@ spring:
- spring.dev.langchain4j.azure.openai.spring.AutoConfig
- spring.dev.langchain4j.azure.aisearch.spring.AutoConfig
- spring.dev.langchain4j.anthropic.spring.AutoConfig
h2:
console:
path: /h2-console/semantic
enabled: true
datasource:
driver-class-name: org.h2.Driver
schema: classpath:db/schema-h2.sql
data: classpath:db/data-h2.sql
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false
username: root
password: semantic
mybatis:
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
@@ -40,4 +28,16 @@ mybatis:
logging:
level:
dev.langchain4j: DEBUG
dev.ai4j.openai4j: DEBUG
dev.ai4j.openai4j: DEBUG
swagger:
title: 'SuperSonic API Documentation'
base:
package: com.tencent.supersonic
description: 'SuperSonic API Documentation'
url: ''
contact:
name:
email:
url: ''
version: 3.0