(improvement)(build) Add spotless during the build process. (#1639)

This commit is contained in:
lexluo09
2024-09-07 00:36:17 +08:00
committed by GitHub
parent ee15a88b06
commit 5f59e89eea
986 changed files with 15609 additions and 12706 deletions

View File

@@ -6,11 +6,12 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
protected final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
protected final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
public EmbeddingStore<TextSegment> create(String collectionName) {
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);
}
public abstract EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName);
}
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.store.embedding;
import lombok.Data;
import java.util.Map;
@@ -13,5 +12,4 @@ public class EmbeddingCollection {
private String name;
private Map<String, String> metaData;
}

View File

@@ -5,5 +5,4 @@ import dev.langchain4j.data.segment.TextSegment;
public interface EmbeddingStoreFactory {
EmbeddingStore<TextSegment> create(String collectionName);
}
}

View File

@@ -12,29 +12,39 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class EmbeddingStoreFactoryProvider {
protected static final Map<EmbeddingStoreConfig, EmbeddingStoreFactory> factoryMap = new ConcurrentHashMap<>();
protected static final Map<EmbeddingStoreConfig, EmbeddingStoreFactory> factoryMap =
new ConcurrentHashMap<>();
public static EmbeddingStoreFactory getFactory() {
EmbeddingStoreParameterConfig parameterConfig = ContextUtils.getBean(EmbeddingStoreParameterConfig.class);
EmbeddingStoreParameterConfig parameterConfig =
ContextUtils.getBean(EmbeddingStoreParameterConfig.class);
return getFactory(parameterConfig.convert());
}
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig embeddingStoreConfig) {
if (embeddingStoreConfig == null || StringUtils.isBlank(embeddingStoreConfig.getProvider())) {
if (embeddingStoreConfig == null
|| StringUtils.isBlank(embeddingStoreConfig.getProvider())) {
return ContextUtils.getBean(EmbeddingStoreFactory.class);
}
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
if (EmbeddingStoreType.IN_MEMORY
.name()
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
}
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: " + embeddingStoreConfig.getProvider());
throw new RuntimeException(
"Unsupported EmbeddingStoreFactory provider: "
+ embeddingStoreConfig.getProvider());
}
}
}

View File

@@ -35,8 +35,9 @@ public class Retrieval {
return false;
}
Retrieval retrieval = (Retrieval) o;
return Double.compare(retrieval.distance, distance) == 0 && Objects.equal(id,
retrieval.id) && Objects.equal(query, retrieval.query)
return Double.compare(retrieval.distance, distance) == 0
&& Objects.equal(id, retrieval.id)
&& Objects.equal(query, retrieval.query)
&& Objects.equal(metadata, retrieval.metadata);
}

View File

@@ -15,6 +15,4 @@ public class RetrieveQuery {
private Map<String, Object> filterCondition;
private List<List<Double>> queryEmbeddings;
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.store.embedding;
import lombok.Data;
import java.util.List;
@@ -11,5 +10,4 @@ public class RetrieveQueryResult {
private String query;
private List<Retrieval> retrieval;
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import dev.langchain4j.data.document.Metadata;
@@ -18,12 +17,20 @@ public class TextSegmentConvert {
public static final String QUERY_ID = "queryId";
public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
return dataItems.stream().map(dataItem -> {
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase());
return textSegment;
}).collect(Collectors.toList());
return dataItems.stream()
.map(
dataItem -> {
Map meta =
JSONObject.parseObject(
JSONObject.toJSONString(dataItem), Map.class);
TextSegment textSegment =
TextSegment.from(dataItem.getName(), new Metadata(meta));
addQueryId(
textSegment,
dataItem.getId() + dataItem.getType().name().toLowerCase());
return textSegment;
})
.collect(Collectors.toList());
}
public static void addQueryId(TextSegment textSegment, String queryId) {

View File

@@ -39,15 +39,17 @@ import static java.util.stream.Collectors.toList;
/**
* An {@link EmbeddingStore} that stores embeddings in memory.
* <p>
* Uses a brute force approach by iterating over all embeddings to find the best matches.
* <p>
* This store can be persisted using the {@link #serializeToJson()} and {@link #serializeToFile(Path)} methods.
* <p>
* It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link #fromFile(Path)} methods.
*
* @param <Embedded> The class of the object that has been embedded.
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
* <p>Uses a brute force approach by iterating over all embeddings to find the best matches.
*
* <p>This store can be persisted using the {@link #serializeToJson()} and {@link
* #serializeToFile(Path)} methods.
*
* <p>It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link
* #fromFile(Path)} methods.
*
* @param <Embedded> The class of the object that has been embedded. Typically, it is {@link
* dev.langchain4j.data.segment.TextSegment}.
*/
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
@@ -80,17 +82,16 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
entries.addAll(newEntries);
return newEntries.stream()
.map(entry -> entry.id)
.collect(toList());
return newEntries.stream().map(entry -> entry.id).collect(toList());
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<Entry<Embedded>> newEntries = embeddings.stream()
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
.collect(toList());
List<Entry<Embedded>> newEntries =
embeddings.stream()
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
.collect(toList());
return add(newEntries);
}
@@ -98,12 +99,15 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
@Override
public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
if (embeddings.size() != embedded.size()) {
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
throw new IllegalArgumentException(
"The list of embeddings and embedded must have the same size");
}
List<Entry<Embedded>> newEntries = IntStream.range(0, embeddings.size())
.mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
.collect(toList());
List<Entry<Embedded>> newEntries =
IntStream.range(0, embeddings.size())
.mapToObj(
i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
.collect(toList());
return add(newEntries);
}
@@ -119,15 +123,16 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
public void removeAll(Filter filter) {
ensureNotNull(filter, "filter");
entries.removeIf(entry -> {
if (entry.embedded instanceof TextSegment) {
return filter.test(((TextSegment) entry.embedded).metadata());
} else if (entry.embedded == null) {
return false;
} else {
throw new UnsupportedOperationException("Not supported yet.");
}
});
entries.removeIf(
entry -> {
if (entry.embedded instanceof TextSegment) {
return filter.test(((TextSegment) entry.embedded).metadata());
} else if (entry.embedded == null) {
return false;
} else {
throw new UnsupportedOperationException("Not supported yet.");
}
});
}
@Override
@@ -152,8 +157,9 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
}
}
double cosineSimilarity = CosineSimilarity.between(entry.embedding,
embeddingSearchRequest.queryEmbedding());
double cosineSimilarity =
CosineSimilarity.between(
entry.embedding, embeddingSearchRequest.queryEmbedding());
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= embeddingSearchRequest.minScore()) {
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));

View File

@@ -42,12 +42,10 @@ import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
/**
* Represents an <a href="https://milvus.io/">Milvus</a> index as an embedding store.
* <br>
* Supports both local and <a href="https://zilliz.com/">managed</a> Milvus instances.
* <br>
* Supports storing {@link Metadata} and filtering by it using a {@link Filter}
* (provided inside an {@link EmbeddingSearchRequest}).
* Represents an <a href="https://milvus.io/">Milvus</a> index as an embedding store. <br>
* Supports both local and <a href="https://zilliz.com/">managed</a> Milvus instances. <br>
* Supports storing {@link Metadata} and filtering by it using a {@link Filter} (provided inside an
* {@link EmbeddingSearchRequest}).
*/
public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
@@ -78,15 +76,14 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
ConsistencyLevelEnum consistencyLevel,
Boolean retrieveEmbeddingsOnSearch,
Boolean autoFlushOnInsert,
String databaseName
) {
ConnectParam.Builder connectBuilder = ConnectParam
.newBuilder()
.withHost(getOrDefault(host, "localhost"))
.withPort(getOrDefault(port, 19530))
.withUri(uri)
.withToken(token)
.withAuthorization(username, password);
String databaseName) {
ConnectParam.Builder connectBuilder =
ConnectParam.newBuilder()
.withHost(getOrDefault(host, "localhost"))
.withPort(getOrDefault(port, 19530))
.withUri(uri)
.withToken(token)
.withAuthorization(username, password);
if (databaseName != null) {
connectBuilder.withDatabaseName(databaseName);
@@ -99,8 +96,13 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
if (!hasCollection(milvusClient, this.collectionName)) {
createCollection(milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType);
createCollection(
milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
createIndex(
milvusClient,
this.collectionName,
getOrDefault(indexType, FLAT),
this.metricType);
}
loadCollectionInMemory(milvusClient, collectionName);
@@ -139,30 +141,33 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
public EmbeddingSearchResult<TextSegment> search(
EmbeddingSearchRequest embeddingSearchRequest) {
SearchParam searchParam = buildSearchRequest(
collectionName,
embeddingSearchRequest.queryEmbedding().vectorAsList(),
embeddingSearchRequest.filter(),
embeddingSearchRequest.maxResults(),
metricType,
consistencyLevel
);
SearchParam searchParam =
buildSearchRequest(
collectionName,
embeddingSearchRequest.queryEmbedding().vectorAsList(),
embeddingSearchRequest.filter(),
embeddingSearchRequest.maxResults(),
metricType,
consistencyLevel);
SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam);
SearchResultsWrapper resultsWrapper =
CollectionOperationsExecutor.search(milvusClient, searchParam);
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(
milvusClient,
resultsWrapper,
collectionName,
consistencyLevel,
retrieveEmbeddingsOnSearch
);
List<EmbeddingMatch<TextSegment>> matches =
toEmbeddingMatches(
milvusClient,
resultsWrapper,
collectionName,
consistencyLevel,
retrieveEmbeddingsOnSearch);
List<EmbeddingMatch<TextSegment>> result = matches.stream()
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
.collect(toList());
List<EmbeddingMatch<TextSegment>> result =
matches.stream()
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
.collect(toList());
return new EmbeddingSearchResult<>(result);
}
@@ -171,15 +176,17 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
addAllInternal(
singletonList(id),
singletonList(embedding),
textSegment == null ? null : singletonList(textSegment)
);
textSegment == null ? null : singletonList(textSegment));
}
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
private void addAllInternal(
List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
List<InsertParam.Field> fields = new ArrayList<>();
fields.add(new InsertParam.Field(ID_FIELD_NAME, ids));
fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size())));
fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
fields.add(
new InsertParam.Field(
METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));
insert(milvusClient, collectionName, fields);
@@ -210,8 +217,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
private String databaseName;
/**
* @param host The host of the self-managed Milvus instance.
* Default value: "localhost".
* @param host The host of the self-managed Milvus instance. Default value: "localhost".
* @return builder
*/
public Builder host(String host) {
@@ -220,8 +226,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param port The port of the self-managed Milvus instance.
* Default value: 19530.
* @param port The port of the self-managed Milvus instance. Default value: 19530.
* @return builder
*/
public Builder port(Integer port) {
@@ -230,9 +235,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param collectionName The name of the Milvus collection.
* If there is no such collection yet, it will be created automatically.
* Default value: "default".
* @param collectionName The name of the Milvus collection. If there is no such collection
* yet, it will be created automatically. Default value: "default".
* @return builder
*/
public Builder collectionName(String collectionName) {
@@ -241,8 +245,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param dimension The dimension of the embedding vector. (e.g. 384)
* Mandatory if a new collection should be created.
* @param dimension The dimension of the embedding vector. (e.g. 384) Mandatory if a new
* collection should be created.
* @return builder
*/
public Builder dimension(Integer dimension) {
@@ -251,8 +255,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param indexType The type of the index.
* Default value: FLAT.
* @param indexType The type of the index. Default value: FLAT.
* @return builder
*/
public Builder indexType(IndexType indexType) {
@@ -261,8 +264,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param metricType The type of the metric used for similarity search.
* Default value: COSINE.
* @param metricType The type of the metric used for similarity search. Default value:
* COSINE.
* @return builder
*/
public Builder metricType(MetricType metricType) {
@@ -271,7 +274,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com")
* @param uri The URI of the managed Milvus instance. (e.g.
* "https://xxx.api.gcp-us-west1.zillizcloud.com")
* @return builder
*/
public Builder uri(String uri) {
@@ -289,7 +293,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param username The username. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
* @param username The username. See details <a
* href="https://milvus.io/docs/authenticate.md">here</a>.
* @return builder
*/
public Builder username(String username) {
@@ -298,7 +303,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param password The password. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
* @param password The password. See details <a
* href="https://milvus.io/docs/authenticate.md">here</a>.
* @return builder
*/
public Builder password(String password) {
@@ -307,8 +313,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param consistencyLevel The consistency level used by Milvus.
* Default value: EVENTUALLY.
* @param consistencyLevel The consistency level used by Milvus. Default value: EVENTUALLY.
* @return builder
*/
public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) {
@@ -317,12 +322,11 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()),
* the embedding itself is not retrieved.
* To retrieve the embedding, an additional query is required.
* Setting this parameter to "true" will ensure that embedding is retrieved.
* Be aware that this will impact the performance of the search.
* Default value: false.
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling
* findRelevant()), the embedding itself is not retrieved. To retrieve the embedding, an
* additional query is required. Setting this parameter to "true" will ensure that
* embedding is retrieved. Be aware that this will impact the performance of the search.
* Default value: false.
* @return builder
*/
public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) {
@@ -331,11 +335,10 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param autoFlushOnInsert Whether to automatically flush after each insert
* ({@code add(...)} or {@code addAll(...)} methods).
* Default value: false.
* More info can be found
* <a href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
* @param autoFlushOnInsert Whether to automatically flush after each insert ({@code
* add(...)} or {@code addAll(...)} methods). Default value: false. More info can be
* found <a
* href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
* @return builder
*/
public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
@@ -344,8 +347,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param databaseName Milvus name of database.
* Default value: null. In this case default Milvus database name will be used.
* @param databaseName Milvus name of database. Default value: null. In this case default
* Milvus database name will be used.
* @return builder
*/
public Builder databaseName(String databaseName) {
@@ -368,8 +371,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
consistencyLevel,
retrieveEmbeddingsOnSearch,
autoFlushOnInsert,
databaseName
);
databaseName);
}
}
}