(improvement)(headless) Add deduplication and persistence to InMemoryEmbeddingStore (#1256)

This commit is contained in:
lexluo09
2024-06-27 22:24:49 +08:00
committed by GitHub
parent 9d921dc47f
commit 391c0dccc8
14 changed files with 361 additions and 26 deletions

View File

@@ -1,5 +1,5 @@
name: SuperSonic enhancement
description: Add an enhanment for SuperSonic
name: SuperSonic enhancement request
description: Add an enhancement for SuperSonic
title: "[Enhancement] "
labels: enhancement

View File

@@ -1,5 +1,5 @@
name: SuperSonic question request
description: Ask a question of SuperSonic
description: Ask a question about SuperSonic
title: "[question] "
labels: question
body:

View File

@@ -260,6 +260,11 @@
<artifactId>spring-boot-autoconfigure-processor</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
</dependency>
</dependencies>
</project>

View File

@@ -7,5 +7,5 @@ import lombok.Setter;
@Setter
class EmbeddingStoreProperties {
private String filePath;
private String persistPath;
}

View File

@@ -1,8 +1,6 @@
package dev.langchain4j.inmemory.spring;
import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -14,6 +12,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.inmemory.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class InMemoryAutoConfig {
@@ -22,8 +22,8 @@ public class InMemoryAutoConfig {
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.file-path")
EmbeddingStoreFactory milvusChatModel(Properties properties) {
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
return new InMemoryEmbeddingStoreFactory(properties);
}

View File

@@ -1,18 +1,27 @@
package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
@Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory.";
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
private Properties properties;
@@ -28,9 +37,63 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
if (Objects.nonNull(embeddingStore)) {
return embeddingStore;
}
embeddingStore = new InMemoryEmbeddingStore();
embeddingStore = reloadFromPersistFile(collectionName);
if (Objects.isNull(embeddingStore)) {
embeddingStore = new InMemoryEmbeddingStore();
}
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
return embeddingStore;
}
/**
 * Tries to restore the store of the given collection from its persist file.
 * <p>
 * Nothing is restored when no persist path is configured, when the file does not exist,
 * or when the collection is the meta collection or the text2sql collection named by
 * {@link EmbeddingConfig}. Any exception raised while loading is logged and not rethrown.
 *
 * @param collectionName name of the collection whose persisted store should be loaded
 * @return the restored store, or {@code null} when nothing was loaded
 */
private InMemoryEmbeddingStore<TextSegment> reloadFromPersistFile(String collectionName) {
    final Path persistFilePath = getPersistPath(collectionName);
    if (persistFilePath == null) {
        return null;
    }
    InMemoryEmbeddingStore<TextSegment> restored = null;
    try {
        EmbeddingConfig config = ContextUtils.getBean(EmbeddingConfig.class);
        // Evaluation order matters: the config getters are only consulted once the file is known to exist.
        boolean reloadable = Files.exists(persistFilePath)
                && !(collectionName.equals(config.getMetaCollectionName())
                || collectionName.equals(config.getText2sqlCollectionName()));
        if (reloadable) {
            restored = InMemoryEmbeddingStore.fromFile(persistFilePath);
            // Re-wrap whatever set type deserialization produced into a CopyOnWriteArraySet.
            restored.entries = new CopyOnWriteArraySet<>(restored.entries);
            log.info("embeddingStore reload from file:{}", persistFilePath);
        }
    } catch (Exception e) {
        log.error("load persistFile error, persistFile:" + persistFilePath, e);
    }
    return restored;
}
/**
 * Writes every cached collection store to its persist file.
 * <p>
 * Collections for which no persist path can be resolved are skipped. A failure while
 * writing one collection is logged and does not stop the remaining collections.
 */
public synchronized void persistFile() {
    if (MapUtils.isEmpty(collectionNameToStore)) {
        return;
    }
    for (Map.Entry<String, InMemoryEmbeddingStore<TextSegment>> cached : collectionNameToStore.entrySet()) {
        Path target = getPersistPath(cached.getKey());
        if (target != null) {
            writeStore(target, cached.getValue());
        }
    }
}

/**
 * Serializes a single store to {@code target}, creating the parent directory when it is missing.
 * Errors are logged and swallowed so one bad collection cannot abort the whole persist run.
 */
private void writeStore(Path target, InMemoryEmbeddingStore<TextSegment> store) {
    try {
        Path parentDir = target.getParent();
        if (!Files.exists(parentDir)) {
            Files.createDirectories(parentDir);
        }
        store.serializeToFile(target);
    } catch (Exception e) {
        log.error("persistFile error, persistFile:" + target, e);
    }
}
/**
 * Resolves the persist file of a collection: the configured persist path joined with
 * {@link #PERSISTENT_FILE_PRE} followed by the collection name.
 *
 * @param collectionName name of the collection
 * @return the file path, or {@code null} when the configured persist path is empty
 */
private Path getPersistPath(String collectionName) {
    String baseDir = properties.getEmbeddingStore().getPersistPath();
    return StringUtils.isEmpty(baseDir)
            ? null
            : Paths.get(baseDir, PERSISTENT_FILE_PRE + collectionName);
}
}

View File

@@ -0,0 +1,250 @@
package dev.langchain4j.store.embedding.inmemory;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.stream.IntStream;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
import static java.util.Comparator.comparingDouble;
import static java.util.stream.Collectors.toList;
/**
 * An {@link EmbeddingStore} that stores embeddings in memory.
 * <p>
 * Uses a brute force approach by iterating over all embeddings to find the best matches.
 * <p>
 * Entries are held in a {@link Set}, so adding an entry that is equal (same id, embedding and
 * embedded object) to one already present does not create a duplicate.
 * <p>
 * This store can be persisted using the {@link #serializeToJson()} and {@link #serializeToFile(Path)} methods.
 * <p>
 * It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link #fromFile(Path)} methods.
 * Files are written and read as UTF-8.
 *
 * @param <Embedded> The class of the object that has been embedded.
 *                   Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
 */
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {

    /**
     * All stored entries. The field is public and non-final so that code outside this class
     * can replace the backing set (for example to re-wrap it after deserialization).
     */
    public Set<Entry<Embedded>> entries = new CopyOnWriteArraySet<>();

    /**
     * Stores an embedding under a freshly generated id.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        String id = randomUUID();
        add(id, embedding);
        return id;
    }

    /**
     * Stores an embedding under the given id, without an embedded object.
     *
     * @param id        the id of the entry; must not be blank
     * @param embedding the embedding to store; must not be null
     */
    @Override
    public void add(String id, Embedding embedding) {
        add(id, embedding, null);
    }

    /**
     * Stores an embedding together with its embedded object under a freshly generated id.
     *
     * @param embedding the embedding to store
     * @param embedded  the object the embedding was computed from; may be null
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, Embedded embedded) {
        String id = randomUUID();
        add(id, embedding, embedded);
        return id;
    }

    /**
     * Stores an embedding together with its embedded object under the given id.
     * Adding an entry equal to an existing one leaves the store unchanged.
     *
     * @param id        the id of the entry; must not be blank
     * @param embedding the embedding to store; must not be null
     * @param embedded  the object the embedding was computed from; may be null
     */
    public void add(String id, Embedding embedding, Embedded embedded) {
        entries.add(new Entry<>(id, embedding, embedded));
    }

    /**
     * Adds the given entries and returns their ids in the order of {@code newEntries}.
     */
    private List<String> add(List<Entry<Embedded>> newEntries) {
        entries.addAll(newEntries);
        return newEntries.stream()
                .map(entry -> entry.id)
                .collect(toList());
    }

    /**
     * Stores several embeddings, each under a freshly generated id.
     *
     * @param embeddings the embeddings to store
     * @return the generated ids, in the order of {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<Entry<Embedded>> newEntries = embeddings.stream()
                .map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
                .collect(toList());
        return add(newEntries);
    }

    /**
     * Stores several embeddings with their embedded objects, pairing them by index.
     *
     * @param embeddings the embeddings to store
     * @param embedded   the embedded objects; must have the same size as {@code embeddings}
     * @return the generated ids, in the order of {@code embeddings}
     * @throws IllegalArgumentException if the two lists differ in size
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
        if (embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        List<Entry<Embedded>> newEntries = IntStream.range(0, embeddings.size())
                .mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
                .collect(toList());
        return add(newEntries);
    }

    /**
     * Removes every entry whose id is contained in {@code ids}.
     *
     * @param ids the ids to remove; must not be empty
     */
    @Override
    public void removeAll(Collection<String> ids) {
        ensureNotEmpty(ids, "ids");
        entries.removeIf(entry -> ids.contains(entry.id));
    }

    /**
     * Removes every entry whose {@link TextSegment} metadata matches the filter.
     * Entries without an embedded object are kept.
     *
     * @param filter the filter to test metadata against; must not be null
     * @throws UnsupportedOperationException if an entry holds an embedded object that is not a {@link TextSegment}
     */
    @Override
    public void removeAll(Filter filter) {
        ensureNotNull(filter, "filter");
        entries.removeIf(entry -> {
            if (entry.embedded instanceof TextSegment) {
                return filter.test(((TextSegment) entry.embedded).metadata());
            } else if (entry.embedded == null) {
                return false;
            } else {
                throw new UnsupportedOperationException("Not supported yet.");
            }
        });
    }

    /**
     * Removes all entries.
     */
    @Override
    public void removeAll() {
        entries.clear();
    }

    /**
     * Scans all entries and returns the best matches for the query embedding.
     * <p>
     * The request filter is only applied to entries whose embedded object is a {@link TextSegment};
     * other entries are scored without filtering. Entries scoring below the request's minimum score
     * are dropped, at most {@code maxResults} matches are kept, and the result is ordered by
     * descending score.
     *
     * @param embeddingSearchRequest the query embedding, optional filter, minimum score and result limit
     * @return the matches, best first
     */
    @Override
    public EmbeddingSearchResult<Embedded> search(EmbeddingSearchRequest embeddingSearchRequest) {
        Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
        // Min-heap on score: the head is always the weakest match kept so far.
        PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
        Filter filter = embeddingSearchRequest.filter();
        for (Entry<Embedded> entry : entries) {
            if (filter != null && entry.embedded instanceof TextSegment) {
                Metadata metadata = ((TextSegment) entry.embedded).metadata();
                if (!filter.test(metadata)) {
                    continue;
                }
            }
            double cosineSimilarity = CosineSimilarity.between(entry.embedding,
                    embeddingSearchRequest.queryEmbedding());
            double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
            if (score >= embeddingSearchRequest.minScore()) {
                matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
                if (matches.size() > embeddingSearchRequest.maxResults()) {
                    // Over the limit: evict the current weakest match.
                    matches.poll();
                }
            }
        }
        List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
        result.sort(comparator);
        Collections.reverse(result);
        return new EmbeddingSearchResult<>(result);
    }

    /**
     * Serializes this store to JSON using the codec returned by {@code loadCodec()}.
     *
     * @return the JSON representation of this store
     */
    public String serializeToJson() {
        return loadCodec().toJson(this);
    }

    /**
     * Serializes this store to JSON and writes it to the given file as UTF-8, creating the file
     * or truncating an existing one.
     *
     * @param filePath the file to write
     * @throws RuntimeException wrapping the {@link IOException} if writing fails
     */
    public void serializeToFile(Path filePath) {
        try {
            String json = serializeToJson();
            // Explicit charset: the platform default may differ between the writing and the reading
            // JVM, which would corrupt non-ASCII text stored in the segments.
            Files.write(filePath, json.getBytes(StandardCharsets.UTF_8), CREATE, TRUNCATE_EXISTING);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Same as {@link #serializeToFile(Path)}, taking the path as a string.
     *
     * @param filePath the file to write
     */
    public void serializeToFile(String filePath) {
        serializeToFile(Paths.get(filePath));
    }

    /**
     * Recreates a store from its JSON representation using the codec returned by {@code loadCodec()}.
     *
     * @param json the JSON produced by {@link #serializeToJson()}
     * @return the deserialized store
     */
    public static InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
        return loadCodec().fromJson(json);
    }

    /**
     * Reads the given file as UTF-8 and recreates a store from the JSON it contains.
     * <p>
     * NOTE(review): files containing non-ASCII text that were written earlier under a non-UTF-8
     * platform default charset will not decode correctly here and need to be regenerated.
     *
     * @param filePath the file to read
     * @return the deserialized store
     * @throws RuntimeException wrapping the {@link IOException} if reading fails
     */
    public static InMemoryEmbeddingStore<TextSegment> fromFile(Path filePath) {
        try {
            // Must match the charset used by serializeToFile.
            String json = new String(Files.readAllBytes(filePath), StandardCharsets.UTF_8);
            return fromJson(json);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Same as {@link #fromFile(Path)}, taking the path as a string.
     *
     * @param filePath the file to read
     * @return the deserialized store
     */
    public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) {
        return fromFile(Paths.get(filePath));
    }

    /**
     * A single stored item: id, embedding vector and the (optional) embedded object.
     * Equality and hash code are computed from all three fields.
     */
    private static class Entry<Embedded> {
        String id;
        Embedding embedding;
        Embedded embedded;

        Entry(String id, Embedding embedding) {
            this(id, embedding, null);
        }

        Entry(String id, Embedding embedding, Embedded embedded) {
            this.id = ensureNotBlank(id, "id");
            this.embedding = ensureNotNull(embedding, "embedding");
            this.embedded = embedded;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Entry<?> that = (Entry<?>) o;
            return Objects.equals(this.id, that.id)
                    && Objects.equals(this.embedding, that.embedding)
                    && Objects.equals(this.embedded, that.embedded);
        }

        @Override
        public int hashCode() {
            return Objects.hash(id, embedding, embedded);
        }
    }

    /**
     * Returns the codec created by the first {@link InMemoryEmbeddingStoreJsonCodecFactory} found
     * by {@code loadFactories}, or a {@link GsonInMemoryEmbeddingStoreJsonCodec} when none is found.
     */
    private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
        for (InMemoryEmbeddingStoreJsonCodecFactory factory :
                loadFactories(InMemoryEmbeddingStoreJsonCodecFactory.class)) {
            return factory.create();
        }
        return new GsonInMemoryEmbeddingStoreJsonCodec();
    }
}

View File

@@ -42,7 +42,6 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
return;
}
sleep();
embeddingService.addCollection(embeddingConfig.getMetaCollectionName());
if (event.getEventType().equals(EventType.ADD)) {
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments);
} else if (event.getEventType().equals(EventType.DELETE)) {

View File

@@ -5,6 +5,8 @@ import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.headless.server.web.service.DimensionService;
import com.tencent.supersonic.headless.server.web.service.MetricService;
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
@@ -28,28 +30,33 @@ public class EmbeddingTask {
@Autowired
private DimensionService dimensionService;
@Autowired
private EmbeddingStoreFactory embeddingStoreFactory;
@PreDestroy
public void onShutdown() {
// embeddingStorePersistentToFile();
embeddingStorePersistFile();
}
// private void embeddingStorePersistentToFile() {
// if (embeddingService instanceof InMemoryEmbeddingService) {
// log.info("start persistentToFile");
// ((InMemoryEmbeddingService) embeddingService).persistentToFile();
// log.info("end persistentToFile");
// }
// }
/**
 * Persists the in-memory embedding stores to file.
 * Does nothing when the configured factory is not an {@link InMemoryEmbeddingStoreFactory}.
 */
private void embeddingStorePersistFile() {
    if (!(embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory)) {
        return;
    }
    log.info("start persistFile");
    ((InMemoryEmbeddingStoreFactory) embeddingStoreFactory).persistFile();
    log.info("end persistFile");
}
@Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
public void executeTask() {
// embeddingStorePersistentToFile();
@Scheduled(cron = "${s2.inMemoryEmbeddingStore.persist.cron:0 0 * * * ?}")
public void executePersistFileTask() {
embeddingStorePersistFile();
}
/***
* reload meta embedding
*/
@Scheduled(cron = "${reload.meta.embedding.corn:0 0 */2 * * ?}")
@Scheduled(cron = "${s2.reload.meta.embedding.corn:0 0 */2 * * ?}")
public void reloadMetaEmbedding() {
log.info("reload.meta.embedding start");
try {

View File

@@ -141,6 +141,12 @@ public class KnowledgeController {
return true;
}
/**
 * Runs the persist-file task on demand by delegating to
 * {@code embeddingTask.executePersistFileTask()}.
 *
 * @return always {@code true} once the task call has returned
 */
@GetMapping("/embedding/persistFile")
public Object executePersistFileTask() {
embeddingTask.executePersistFileTask();
return true;
}
/**
* queryDictValue-返回字典的数据
*
@@ -161,8 +167,8 @@ public class KnowledgeController {
*/
@PostMapping("/dict/file")
public String queryDictFilePath(@RequestBody @Valid DictValueReq dictValueReq,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return taskService.queryDictFilePath(dictValueReq, user);
}

View File

@@ -13,4 +13,4 @@ langchain4j:
embedding-model:
model-name: bge-small-zh
embedding-store:
file-path: /tmp
persist-path: /tmp

View File

@@ -13,4 +13,4 @@ langchain4j:
embedding-model:
model-name: bge-small-zh
embedding-store:
file-path: /tmp
persist-path: /tmp

View File

@@ -215,6 +215,11 @@
<artifactId>spring-boot-autoconfigure-processor</artifactId>
<version>${spring.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>${gson.version}</version>
</dependency>
</dependencies>
</dependencyManagement>