(improvement)(chat) Upgrade and optimize the embedding metastore. (#1198)

This commit is contained in:
lexluo09
2024-06-23 21:46:10 +08:00
committed by GitHub
parent 2ae94fb38c
commit 4d6cbf31f7
46 changed files with 3788 additions and 498 deletions

View File

@@ -0,0 +1,11 @@
package dev.langchain4j.inmemory.spring;
import lombok.Getter;
import lombok.Setter;
/**
 * Settings holder for the in-memory embedding store.
 * Used as the nested {@code embeddingStore} group of {@link Properties}.
 * Accessors are generated by Lombok.
 */
@Getter
@Setter
class EmbeddingStoreProperties {
// File path setting; InMemoryAutoConfig registers its factory bean only when the matching
// "embedding-store.file-path" property is set.
// NOTE(review): the visible InMemoryEmbeddingStoreFactory never reads this value — confirm intended use.
private String filePath;
}

View File

@@ -0,0 +1,21 @@
package dev.langchain4j.inmemory.spring;
import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Spring configuration that contributes an in-memory {@link EmbeddingStoreFactory}
 * and enables binding of {@link Properties}.
 */
@Configuration
@EnableConfigurationProperties(Properties.class)
public class InMemoryAutoConfig {
// NOTE(review): the bean method is called milvusChatModel although it builds the in-memory
// EmbeddingStoreFactory — this looks like a copy-paste leftover. Renaming it changes the bean
// name, so confirm nothing looks the bean up by name before fixing.
/**
 * Builds the in-memory factory from the bound properties.
 * The bean is conditional on the {@code langchain4j.in-memory.embedding-store.file-path} property.
 *
 * @param properties bound {@code langchain4j.in-memory} configuration
 * @return a new {@link InMemoryEmbeddingStoreFactory}
 */
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.file-path")
EmbeddingStoreFactory milvusChatModel(Properties properties) {
return new InMemoryEmbeddingStoreFactory(properties);
}
}

View File

@@ -0,0 +1,37 @@
package dev.langchain4j.inmemory.spring;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;
/**
 * {@link EmbeddingStoreFactory} that hands out one {@link InMemoryEmbeddingStore} per collection name.
 * <p>
 * Stores are cached in a static map, so every factory instance in the JVM sees the same store
 * for a given collection name.
 */
@Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {

    /** Stores keyed by collection name; static, hence shared by all factory instances. */
    private static final Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> STORES_BY_COLLECTION =
            new ConcurrentHashMap<>();

    /** Bound configuration handed in by the auto-configuration; not read by this class yet. */
    private final Properties properties;

    /**
     * @param properties bound {@code langchain4j.in-memory} configuration
     */
    public InMemoryEmbeddingStoreFactory(Properties properties) {
        this.properties = properties;
    }

    /**
     * Returns the store registered for {@code collectionName}, creating it on first use.
     *
     * @param collectionName name of the collection; must not be {@code null}
     * @return the cached or newly created in-memory store
     * @throws NullPointerException if {@code collectionName} is {@code null}
     */
    @Override
    public EmbeddingStore create(String collectionName) {
        Objects.requireNonNull(collectionName, "collectionName");
        // computeIfAbsent is atomic on ConcurrentHashMap, so no extra locking is needed and
        // exactly one store is ever created per collection name.
        return STORES_BY_COLLECTION.computeIfAbsent(collectionName, name -> new InMemoryEmbeddingStore<>());
    }
}

View File

@@ -0,0 +1,17 @@
package dev.langchain4j.inmemory.spring;
import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
/**
 * Configuration properties bound from the {@code langchain4j.in-memory} prefix.
 * Accessors are generated by Lombok.
 */
@Getter
@Setter
@ConfigurationProperties(prefix = Properties.PREFIX)
public class Properties {
// Property prefix; also used by InMemoryAutoConfig to build its conditional property name.
static final String PREFIX = "langchain4j.in-memory";
// Nested settings group for the embedding store.
// Presumably bound under "<PREFIX>.embedding-store" via relaxed binding — matches the
// property name InMemoryAutoConfig checks.
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}

View File

@@ -1,23 +0,0 @@
package dev.langchain4j.store.embedding;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.Objects;
/**
 * Static holder that lazily resolves the {@link S2EmbeddingStore} implementation
 * through {@link SpringFactoriesLoader}.
 */
public class ComponentFactory {

    /** Cached implementation; populated on the first call to {@link #getS2EmbeddingStore()}. */
    private static S2EmbeddingStore s2EmbeddingStore;

    /**
     * Returns the cached {@link S2EmbeddingStore}, loading it on first access.
     *
     * @return the first implementation reported by the factories loader
     */
    public static S2EmbeddingStore getS2EmbeddingStore() {
        if (Objects.nonNull(s2EmbeddingStore)) {
            return s2EmbeddingStore;
        }
        s2EmbeddingStore = init(S2EmbeddingStore.class);
        return s2EmbeddingStore;
    }

    /**
     * Loads all factories of the given type with the current thread's context class loader
     * and returns the first one.
     */
    private static <T> T init(Class<T> factoryType) {
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        return SpringFactoriesLoader.loadFactories(factoryType, contextLoader).get(0);
    }
}

View File

@@ -19,7 +19,6 @@ public class EmbeddingQuery {
private Map<String, Object> metadata;
private List<Double> queryEmbedding;
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
return dataItems.stream().map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();

View File

@@ -0,0 +1,6 @@
package dev.langchain4j.store.embedding;
/**
 * Factory that supplies an {@link EmbeddingStore} for a named collection.
 */
public interface EmbeddingStoreFactory {
// NOTE(review): the return type is the raw EmbeddingStore; callers get no compile-time
// information about the embedded payload type.
/**
 * Supplies the embedding store for the given collection.
 *
 * @param collectionName name identifying the collection
 * @return the store for that collection
 */
EmbeddingStore create(String collectionName);
}

View File

@@ -1,21 +0,0 @@
package dev.langchain4j.store.embedding;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;
/**
 * Gson-backed implementation of {@link InMemoryEmbeddingStoreJsonCodec}.
 */
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {

    /** Gson instances are thread-safe, so a single shared instance replaces one allocation per call. */
    private static final Gson GSON = new Gson();

    /**
     * Parses a store of {@link EmbeddingQuery} entries from JSON.
     *
     * @param json JSON previously produced by {@link #toJson}
     * @return the deserialized store, as returned by Gson
     */
    @Override
    public InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
        // The TypeToken keeps the generic argument so Gson rebuilds EmbeddingQuery payloads
        // rather than generic maps.
        Type type = new TypeToken<InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<EmbeddingQuery>>() {
        }.getType();
        return GSON.fromJson(json, type);
    }

    /**
     * Serializes the given store to JSON.
     *
     * @param store store to serialize
     * @return JSON representation of {@code store}
     */
    @Override
    public String toJson(InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<?> store) {
        return GSON.toJson(store);
    }
}

View File

@@ -1,10 +0,0 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.store.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
/**
 * JSON serialization contract for {@link InMemoryEmbeddingStore}.
 */
public interface InMemoryEmbeddingStoreJsonCodec {
/**
 * Rebuilds a store of {@link EmbeddingQuery} entries from its JSON form.
 *
 * @param json JSON text describing a store
 * @return the deserialized store
 */
InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json);
/**
 * Renders the given store as JSON.
 *
 * @param store store to serialize
 * @return JSON text
 */
String toJson(InMemoryEmbeddingStore<?> store);
}

View File

@@ -1,363 +0,0 @@
package dev.langchain4j.store.embedding;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
import static java.util.Comparator.comparingDouble;
/**
 * In-memory implementation of S2EmbeddingStore: collections are held in the memory of the running Java process.
 */
@Slf4j
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
new ConcurrentHashMap<>();
/**
 * Registers a store for {@code collectionName} if none is registered yet.
 * <p>
 * When a persisted file exists for the collection, and the collection is neither the meta
 * nor the text2sql collection configured in {@link EmbeddingConfig}, the store is reloaded
 * from that file. Any failure while loading is logged and an empty store is used instead.
 *
 * @param collectionName name of the collection to register
 */
@Override
public synchronized void addCollection(String collectionName) {
    // Already registered: putIfAbsent below would discard whatever we build,
    // so skip the file read entirely.
    if (collectionNameToStore.containsKey(collectionName)) {
        return;
    }
    InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null;
    Path filePath = getPersistentPath(collectionName);
    try {
        EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
        if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName())
                && !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
            InMemoryEmbeddingStore<EmbeddingQuery> loaded = InMemoryEmbeddingStore.fromFile(filePath);
            // The deserialized set is not a CopyOnWriteArraySet; re-wrap it for concurrent use.
            loaded.entries = new CopyOnWriteArraySet<>(loaded.entries);
            // Publish only after the store is fully usable, so a failure above can never
            // register a half-initialised store (e.g. one whose entries are null).
            embeddingStore = loaded;
            log.info("embeddingStore reload from file:{}", filePath);
        }
    } catch (Exception e) {
        log.error("load persistentFile error, persistentFile:{}", filePath, e);
    }
    if (Objects.isNull(embeddingStore)) {
        embeddingStore = new InMemoryEmbeddingStore<>();
    }
    collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
}
/**
 * Resolves the persistence file for a collection: the configured persistent directory
 * joined with {@code PERSISTENT_FILE_PRE + collectionName}.
 *
 * @param collectionName name of the collection
 * @return path of the file that holds the collection's persisted state
 */
private Path getPersistentPath(String collectionName) {
    EmbeddingConfig config = ContextUtils.getBean(EmbeddingConfig.class);
    return Paths.get(config.getEmbeddingStorePersistentPath(), PERSISTENT_FILE_PRE + collectionName);
}
/**
 * Writes every registered collection to its persistence file.
 * <p>
 * A failure for one collection is logged and does not stop the remaining collections
 * from being persisted.
 */
public void persistentToFile() {
    for (Entry<String, InMemoryEmbeddingStore<EmbeddingQuery>> entry : collectionNameToStore.entrySet()) {
        Path filePath = getPersistentPath(entry.getKey());
        try {
            // getParent() is null for a single-element relative path (e.g. when the configured
            // directory is empty); in that case there is no directory to create.
            Path directoryPath = filePath.getParent();
            if (Objects.nonNull(directoryPath) && !Files.exists(directoryPath)) {
                Files.createDirectories(directoryPath);
            }
            entry.getValue().serializeToFile(filePath);
        } catch (Exception e) {
            log.error("persistentToFile error, persistentFile:{}", filePath, e);
        }
    }
}
/**
 * Embeds each query's text and adds it to the collection's store under the query's id.
 *
 * @param collectionName target collection (registered on demand)
 * @param queries        queries to embed and store
 */
@Override
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
    InMemoryEmbeddingStore<EmbeddingQuery> store = getEmbeddingStore(collectionName);
    EmbeddingModel model = getEmbeddingModel();
    for (EmbeddingQuery item : queries) {
        Embedding vector = model.embed(item.getQuery()).content();
        store.add(item.getQueryId(), vector, item);
    }
}
/**
 * Looks up the {@link EmbeddingModel} bean, falling back to a new
 * {@link BgeSmallZhEmbeddingModel} when no such bean is defined.
 *
 * @return the embedding model to use
 */
private static EmbeddingModel getEmbeddingModel() {
    try {
        return ContextUtils.getBean(EmbeddingModel.class);
    } catch (NoSuchBeanDefinitionException missingBean) {
        return new BgeSmallZhEmbeddingModel();
    }
}
/**
 * Returns the store for {@code collectionName}, registering the collection first
 * when it is not present yet.
 *
 * @param collectionName name of the collection
 * @return the store registered for the collection
 */
private InMemoryEmbeddingStore<EmbeddingQuery> getEmbeddingStore(String collectionName) {
    InMemoryEmbeddingStore<EmbeddingQuery> existing = collectionNameToStore.get(collectionName);
    if (Objects.nonNull(existing)) {
        return existing;
    }
    // Not registered yet: create it under the class lock, then read back what is in the map.
    synchronized (InMemoryS2EmbeddingStore.class) {
        addCollection(collectionName);
        return collectionNameToStore.get(collectionName);
    }
}
/**
 * Intentionally a no-op: deleting queries is not supported by the in-memory store,
 * so the arguments are ignored and nothing is removed.
 */
@Override
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
// Not supported in InMemoryEmbeddingStore.
}
/**
 * Retrieves, for every query text, up to {@code num} stored entries that pass the filter condition.
 * <p>
 * Each match is reported with {@code distance = 1 - score}. When a filter condition is present,
 * more candidates are fetched (see {@code getMaxResults}) so that enough entries survive filtering.
 *
 * @param collectionName collection to search (registered on demand)
 * @param retrieveQuery  query texts and optional metadata filter
 * @param num            maximum number of retrievals per query text
 * @return one result per query text, in the order of the query texts
 */
@Override
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
    InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
    EmbeddingModel embeddingModel = getEmbeddingModel();
    List<String> queryTextsList = retrieveQuery.getQueryTextsList();
    Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
    // Depends only on num and the filter, so compute it once instead of once per query text.
    int maxResults = getMaxResults(num, filterCondition);
    Comparator<Retrieval> byDistance = Comparator.comparingDouble(Retrieval::getDistance);

    List<RetrieveQueryResult> results = new ArrayList<>();
    for (String queryText : queryTextsList) {
        Embedding embeddedText = embeddingModel.embed(queryText).content();
        List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, maxResults);

        List<Retrieval> retrievals = new ArrayList<>();
        for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
            // Entries may be stored without a payload; read it once and guard every use.
            EmbeddingQuery embedded = embeddingMatch.embedded();
            Map<String, Object> metadata = new HashMap<>();
            if (Objects.nonNull(embedded) && MapUtils.isNotEmpty(embedded.getMetadata())) {
                metadata.putAll(embedded.getMetadata());
            }
            if (filterRetrieval(filterCondition, metadata)) {
                continue;
            }
            Retrieval retrieval = new Retrieval();
            retrieval.setDistance(1 - embeddingMatch.score());
            retrieval.setId(embeddingMatch.embeddingId());
            if (Objects.nonNull(embedded)) {
                retrieval.setQuery(embedded.getQuery());
            }
            retrieval.setMetadata(metadata);
            retrievals.add(retrieval);
        }

        // Keep the num CLOSEST matches. The previous code sorted by descending distance before
        // limiting, which kept the farthest ones whenever filtering had widened the candidate set.
        // NOTE(review): the returned order is still descending distance, as before, so callers
        // that rely on the ordering see no change — confirm whether closest-first is wanted.
        retrievals = retrievals.stream()
                .sorted(byDistance)
                .limit(num)
                .sorted(byDistance.reversed())
                .collect(Collectors.toList());

        RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
        retrieveQueryResult.setQuery(queryText);
        retrieveQueryResult.setRetrieval(retrievals);
        results.add(retrieveQueryResult);
    }
    return results;
}
/**
 * Decides how many candidates to fetch from the store: {@code num} without a filter,
 * {@code num * 5} when a filter condition is present.
 *
 * @param num             number of results requested
 * @param filterCondition metadata filter, may be null or empty
 * @return number of candidates to request from the store
 */
private int getMaxResults(int num, Map<String, String> filterCondition) {
    return MapUtils.isNotEmpty(filterCondition) ? num * 5 : num;
}
/**
 * Tells whether an entry must be dropped because its metadata conflicts with the filter.
 * <p>
 * Only metadata keys that have a non-blank filter value are compared; the comparison is
 * case-insensitive on the metadata value's string form. A null metadata value under a
 * filtered key counts as a mismatch.
 *
 * @param filterCondition expected values per metadata key, may be null or empty
 * @param metadata        metadata of the candidate entry, may be null or empty
 * @return {@code true} if the entry should be filtered out, {@code false} otherwise
 */
private boolean filterRetrieval(Map<String, String> filterCondition, Map<String, Object> metadata) {
    if (!(MapUtils.isNotEmpty(metadata) && MapUtils.isNotEmpty(filterCondition))) {
        return false;
    }
    for (Entry<String, Object> entry : metadata.entrySet()) {
        String filterValue = filterCondition.get(entry.getKey());
        if (!StringUtils.isNotBlank(filterValue)) {
            continue;
        }
        Object actual = entry.getValue();
        // A null value cannot equal a non-blank filter value; previously this threw an NPE.
        if (Objects.isNull(actual) || !filterValue.equalsIgnoreCase(actual.toString())) {
            return true;
        }
    }
    return false;
}
/**
 * An {@link EmbeddingStore} that stores embeddings in memory.
 * <p>
 * Uses a brute force approach by iterating over all embeddings to find the best matches.
 * <p>
 * Copied from {@code dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore}; entries are
 * kept in a {@link CopyOnWriteArraySet} to avoid ConcurrentModificationException in a
 * multi-threaded environment.
 *
 * @param <Embedded> The class of the object that has been embedded.
 *                   Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
 */
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {

    /** One stored embedding with its id and optional payload. */
    private static class Entry<Embedded> {

        String id;
        Embedding embedding;
        Embedded embedded;

        Entry(String id, Embedding embedding, Embedded embedded) {
            this.id = id;
            this.embedding = embedding;
            this.embedded = embedded;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            InMemoryEmbeddingStore.Entry<?> that = (InMemoryEmbeddingStore.Entry<?>) o;
            return Objects.equals(this.id, that.id)
                    && Objects.equals(this.embedding, that.embedding)
                    && Objects.equals(this.embedded, that.embedded);
        }

        @Override
        public int hashCode() {
            return Objects.hash(id, embedding, embedded);
        }
    }

    private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();

    /** Stored entries; not final because the enclosing class re-wraps the set after loading from file. */
    private Set<Entry<Embedded>> entries = new CopyOnWriteArraySet<>();

    /** Adds an embedding under a freshly generated id and returns that id. */
    @Override
    public String add(Embedding embedding) {
        String id = randomUUID();
        add(id, embedding);
        return id;
    }

    /** Adds an embedding under the given id, without a payload. */
    @Override
    public void add(String id, Embedding embedding) {
        add(id, embedding, null);
    }

    /** Adds an embedding with its payload under a freshly generated id and returns that id. */
    @Override
    public String add(Embedding embedding, Embedded embedded) {
        String id = randomUUID();
        add(id, embedding, embedded);
        return id;
    }

    /** Adds an embedding with its payload under the given id. */
    public void add(String id, Embedding embedding, Embedded embedded) {
        entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded));
    }

    /** Adds every embedding under a generated id and returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = new ArrayList<>();
        for (Embedding embedding : embeddings) {
            ids.add(add(embedding));
        }
        return ids;
    }

    /**
     * Adds each embedding together with the payload at the same index.
     *
     * @throws IllegalArgumentException if the two lists differ in size
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
        if (embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        List<String> ids = new ArrayList<>();
        for (int i = 0; i < embeddings.size(); i++) {
            ids.add(add(embeddings.get(i), embedded.get(i)));
        }
        return ids;
    }

    /**
     * Scans all entries and returns at most {@code maxResults} matches whose relevance score
     * is at least {@code minScore}, ordered from highest to lowest score.
     */
    @Override
    public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding, int maxResults,
            double minScore) {
        Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
        // Min-heap on score: once it grows past maxResults the weakest match is evicted.
        PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
        for (InMemoryEmbeddingStore.Entry<Embedded> entry : entries) {
            double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
            double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
            if (score >= minScore) {
                matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
                if (matches.size() > maxResults) {
                    matches.poll();
                }
            }
        }
        List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
        result.sort(comparator);
        Collections.reverse(result);
        return result;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        InMemoryEmbeddingStore<?> that = (InMemoryEmbeddingStore<?>) o;
        return Objects.equals(this.entries, that.entries);
    }

    @Override
    public int hashCode() {
        return Objects.hash(entries);
    }

    /** Serializes this store to JSON through the configured codec. */
    public String serializeToJson() {
        return CODEC.toJson(this);
    }

    /**
     * Writes this store as JSON to {@code filePath}, creating or truncating the file.
     *
     * @throws RuntimeException wrapping the {@link IOException} if writing fails
     */
    public void serializeToFile(Path filePath) {
        try {
            String json = serializeToJson();
            // Explicit UTF-8: the platform default charset differs between machines and would
            // corrupt non-ASCII query text when a file is written on one host and read on another.
            Files.write(filePath, json.getBytes(java.nio.charset.StandardCharsets.UTF_8),
                    CREATE, TRUNCATE_EXISTING);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Same as {@link #serializeToFile(Path)} for a path given as a string. */
    public void serializeToFile(String filePath) {
        serializeToFile(Paths.get(filePath));
    }

    private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
        // fallback to default
        return new GsonInMemoryEmbeddingStoreJsonCodec();
    }

    /** Rebuilds a store from JSON through the configured codec. */
    public static InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
        return CODEC.fromJson(json);
    }

    /**
     * Reads {@code filePath} as UTF-8 JSON and rebuilds the store from it.
     *
     * @throws RuntimeException wrapping the {@link IOException} if reading fails
     */
    public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(Path filePath) {
        try {
            // Must use the same charset serializeToFile writes with.
            String json = new String(Files.readAllBytes(filePath), java.nio.charset.StandardCharsets.UTF_8);
            return fromJson(json);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Same as {@link #fromFile(Path)} for a path given as a string. */
    public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(String filePath) {
        return fromFile(Paths.get(filePath));
    }
}
}

View File

@@ -4,6 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.ParameterizedTypeReference;
@@ -25,7 +26,7 @@ import java.util.stream.Collectors;
* Implementation of calling the Python service S2EmbeddingStore.
*/
@Slf4j
public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
public class PythonServiceEmbeddingService implements EmbeddingService {
private RestTemplate restTemplate = new RestTemplate();

View File

@@ -1,19 +0,0 @@
package dev.langchain4j.store.embedding;
import java.util.List;
/**
 * Supersonic EmbeddingStore.
 * Adds collection names to the embedding-store operations: queries are added to,
 * deleted from and retrieved from a named collection.
 */
public interface S2EmbeddingStore {
/**
 * Registers a collection under the given name.
 *
 * @param collectionName name of the collection
 */
void addCollection(String collectionName);
/**
 * Adds the given queries to a collection.
 *
 * @param collectionName name of the target collection
 * @param queries        queries to add
 */
void addQuery(String collectionName, List<EmbeddingQuery> queries);
/**
 * Removes the given queries from a collection.
 * NOTE(review): the in-memory implementation leaves this as a no-op — confirm callers tolerate that.
 *
 * @param collectionName name of the collection
 * @param queries        queries to delete
 */
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
/**
 * Retrieves stored queries from a collection for the given retrieve query.
 *
 * @param collectionName name of the collection to search
 * @param retrieveQuery  the retrieval request
 * @param num            number of results requested (the in-memory implementation treats it as a
 *                       per-query-text maximum)
 * @return the retrieval results
 */
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
}