[improvement](chat) Add persistence to InMemoryEmbeddingStore and fix the issue of PythonServiceS2EmbeddingStore being empty. (#524)

This commit is contained in:
lexluo09
2023-12-17 11:04:29 +08:00
committed by GitHub
parent 3db443f9b1
commit fe75b3e393
6 changed files with 94 additions and 20 deletions

View File

@@ -37,7 +37,7 @@ public class JavaLLMProxy implements LLMProxy {
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode()));
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getModelName();
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);

View File

@@ -69,5 +69,14 @@ public class LLMReq {
return name;
}
/**
 * Looks up the generation mode whose {@code name} field equals the given text.
 *
 * @param name the mode name to search for
 * @return the matching mode, or {@code null} when no mode carries that name
 */
public static SqlGenerationMode getMode(String name) {
    SqlGenerationMode[] candidates = SqlGenerationMode.values();
    for (int i = 0; i < candidates.length; i++) {
        SqlGenerationMode candidate = candidates[i];
        if (!candidate.name.equals(name)) {
            continue;
        }
        return candidate;
    }
    // No mode matched; callers receive null.
    return null;
}
}
}

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.common.util.embedding;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
import java.lang.reflect.Type;
/**
 * {@link InMemoryEmbeddingStoreJsonCodec} implementation that delegates to Gson.
 */
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {

    /** Shared serializer; Gson instances are thread-safe, so one is reused for every call. */
    private static final Gson GSON = new Gson();

    /** Full generic target type handed to Gson when decoding a store of {@link TextSegment}. */
    private static final Type STORE_TYPE = new TypeToken<InMemoryEmbeddingStore<TextSegment>>() {
    }.getType();

    /**
     * Decodes a store from JSON.
     *
     * @param json JSON text to decode
     * @return the store produced by Gson for the given JSON
     */
    @Override
    public InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
        return GSON.fromJson(json, STORE_TYPE);
    }

    /**
     * Encodes a store as JSON.
     *
     * @param store the store to encode
     * @return the JSON text produced by Gson
     */
    @Override
    public String toJson(InMemoryEmbeddingStore<?> store) {
        return GSON.toJson(store);
    }
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.common.util.embedding;
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
import dev.langchain4j.data.segment.TextSegment;
/**
 * Converts an {@link InMemoryEmbeddingStore} to and from a JSON string.
 */
public interface InMemoryEmbeddingStoreJsonCodec {
    /**
     * Builds a store from its JSON form.
     *
     * @param json JSON text to decode
     * @return the decoded store holding {@link TextSegment} values
     */
    InMemoryEmbeddingStore<TextSegment> fromJson(String json);
    /**
     * Produces the JSON form of a store.
     *
     * @param store the store to encode, with any embedded type
     * @return the JSON text
     */
    String toJson(InMemoryEmbeddingStore<?> store);
}

View File

@@ -1,21 +1,23 @@
package com.tencent.supersonic.common.util.embedding;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
import static java.util.Comparator.comparingDouble;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.spi.ServiceHelper;
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
@@ -40,7 +42,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
new ConcurrentHashMap<>();
@Override
public void addCollection(String collectionName) {
public synchronized void addCollection(String collectionName) {
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
}
@@ -164,7 +166,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
if (o == null || getClass() != o.getClass()) {
return false;
}
Entry<?> that = (Entry<?>) o;
InMemoryEmbeddingStore.Entry<?> that = (InMemoryEmbeddingStore.Entry<?>) o;
return Objects.equals(this.id, that.id)
&& Objects.equals(this.embedding, that.embedding)
&& Objects.equals(this.embedded, that.embedded);
@@ -176,7 +178,8 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
}
}
private final List<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();
private final List<InMemoryEmbeddingStore.Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
@Override
public String add(Embedding embedding) {
@@ -198,7 +201,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
}
public void add(String id, Embedding embedding, Embedded embedded) {
entries.add(new Entry<>(id, embedding, embedded));
entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded));
}
@Override
@@ -230,7 +233,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
for (Entry<Embedded> entry : entries) {
for (InMemoryEmbeddingStore.Entry<Embedded> entry : entries) {
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= minScore) {
@@ -264,16 +267,44 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
return Objects.hash(entries);
}
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
Collection<InMemoryEmbeddingStoreJsonCodecFactory> factories = ServiceHelper.loadFactories(
InMemoryEmbeddingStoreJsonCodecFactory.class);
for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) {
return factory.create();
/**
 * Encodes this store as JSON through the shared codec.
 *
 * @return the JSON text returned by the codec for this store
 */
public String serializeToJson() {
    final String encoded = CODEC.toJson(this);
    return encoded;
}
/**
 * Writes this store to a file as UTF-8 encoded JSON, creating the file if it is missing
 * and truncating it if it already exists.
 *
 * @param filePath destination file
 * @throws RuntimeException wrapping the {@link IOException} when the file cannot be written
 */
public void serializeToFile(Path filePath) {
    try {
        String json = serializeToJson();
        // Explicit charset: the no-arg getBytes() would use the platform default, making the
        // file unreadable on a machine with a different default when it holds non-ASCII text.
        byte[] payload = json.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        Files.write(filePath, payload, CREATE, TRUNCATE_EXISTING);
    } catch (IOException e) {
        throw new RuntimeException("Failed to write embedding store to " + filePath, e);
    }
}

/**
 * Writes this store to the file at the given path string.
 *
 * @param filePath destination file path
 * @throws RuntimeException wrapping the {@link IOException} when the file cannot be written
 */
public void serializeToFile(String filePath) {
    serializeToFile(Paths.get(filePath));
}

/**
 * Supplies the codec used for JSON conversion.
 *
 * @return a new Gson-based codec
 */
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
    // Always the Gson-based codec; no other implementation is looked up here.
    return new GsonInMemoryEmbeddingStoreJsonCodec();
}

/**
 * Rebuilds a store from JSON through the shared codec.
 *
 * @param json JSON text to decode
 * @return the decoded store
 */
public static InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
    return CODEC.fromJson(json);
}

/**
 * Reads a store from a file containing UTF-8 encoded JSON.
 *
 * @param filePath file to read
 * @return the decoded store
 * @throws RuntimeException wrapping the {@link IOException} when the file cannot be read
 */
public static InMemoryEmbeddingStore<TextSegment> fromFile(Path filePath) {
    try {
        // Decode with the same charset serializeToFile writes with.
        String json = new String(Files.readAllBytes(filePath), java.nio.charset.StandardCharsets.UTF_8);
        return fromJson(json);
    } catch (IOException e) {
        throw new RuntimeException("Failed to read embedding store from " + filePath, e);
    }
}

/**
 * Reads a store from the file at the given path string.
 *
 * @param filePath path of the file to read
 * @return the decoded store
 * @throws RuntimeException wrapping the {@link IOException} when the file cannot be read
 */
public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) {
    return fromFile(Paths.get(filePath));
}
}
}

View File

@@ -4,12 +4,12 @@ import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
@@ -26,9 +26,6 @@ import org.springframework.web.util.UriComponentsBuilder;
@Slf4j
public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
@Autowired
private EmbeddingConfig embeddingConfig;
private RestTemplate restTemplate = new RestTemplate();
public void addCollection(String collectionName) {
@@ -36,6 +33,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
if (collections.contains(collectionName)) {
return;
}
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = String.format("%s/create_collection?collection_name=%s",
embeddingConfig.getUrl(), collectionName);
doRequest(url, null, HttpMethod.GET);
@@ -45,6 +43,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
if (CollectionUtils.isEmpty(queries)) {
return;
}
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = String.format("%s/add_query?collection_name=%s",
embeddingConfig.getUrl(), collectionName);
doRequest(url, JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue), HttpMethod.POST);
@@ -54,6 +53,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
if (CollectionUtils.isEmpty(queries)) {
return;
}
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
List<String> queryIds = queries.stream().map(EmbeddingQuery::getQueryId).collect(Collectors.toList());
String url = String.format("%s/delete_query_by_ids?collection_name=%s",
embeddingConfig.getUrl(), collectionName);
@@ -61,6 +61,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
}
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s",
embeddingConfig.getUrl(), collectionName, num);
ResponseEntity<String> responseEntity = doRequest(url, JSONObject.toJSONString(retrieveQuery,
@@ -72,6 +73,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
}
private List<String> getCollectionList() {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = embeddingConfig.getUrl() + "/list_collections";
ResponseEntity<String> responseEntity = doRequest(url, null, HttpMethod.GET);
if (!responseEntity.hasBody()) {