From fe75b3e3935631809c92cf64ed801ab0758fc093 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sun, 17 Dec 2023 11:04:29 +0800 Subject: [PATCH] [improvement](chat) Add persistence to InMemoryEmbeddingStore, fix the issue of PythonServiceS2EmbeddingStore being empty. (#524) --- .../supersonic/chat/parser/JavaLLMProxy.java | 2 +- .../chat/query/llm/s2sql/LLMReq.java | 9 +++ .../GsonInMemoryEmbeddingStoreJsonCodec.java | 22 +++++++ .../InMemoryEmbeddingStoreJsonCodec.java | 10 +++ .../embedding/InMemoryS2EmbeddingStore.java | 61 ++++++++++++++----- .../PythonServiceS2EmbeddingStore.java | 10 +-- 6 files changed, 94 insertions(+), 20 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/embedding/GsonInMemoryEmbeddingStoreJsonCodec.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryEmbeddingStoreJsonCodec.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java index 57083c176..1375b0d1c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java @@ -37,7 +37,7 @@ public class JavaLLMProxy implements LLMProxy { public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) { SqlGeneration sqlGeneration = SqlGenerationFactory.get( - SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode())); + SqlGenerationMode.getMode(llmReq.getSqlGenerationMode())); String modelName = llmReq.getSchema().getModelName(); Map sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java index c11676201..30d7fbf3f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java @@ -69,5 +69,14 @@ public class LLMReq { return name; } + public static SqlGenerationMode getMode(String name) { + for (SqlGenerationMode sqlGenerationMode : SqlGenerationMode.values()) { + if (sqlGenerationMode.name.equals(name)) { + return sqlGenerationMode; + } + } + return null; + } + } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/GsonInMemoryEmbeddingStoreJsonCodec.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/GsonInMemoryEmbeddingStoreJsonCodec.java new file mode 100644 index 000000000..483cfd4e7 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/GsonInMemoryEmbeddingStoreJsonCodec.java @@ -0,0 +1,22 @@ +package com.tencent.supersonic.common.util.embedding; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore; +import dev.langchain4j.data.segment.TextSegment; +import java.lang.reflect.Type; + +public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec { + + @Override + public InMemoryEmbeddingStore fromJson(String json) { + Type type = new TypeToken>() { + }.getType(); + return new Gson().fromJson(json, type); + } + + @Override + public String toJson(InMemoryEmbeddingStore store) { + return new Gson().toJson(store); + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryEmbeddingStoreJsonCodec.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryEmbeddingStoreJsonCodec.java new file mode 100644 index 000000000..b8ff372f4 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryEmbeddingStoreJsonCodec.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.common.util.embedding; + +import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore; +import dev.langchain4j.data.segment.TextSegment; + +public interface InMemoryEmbeddingStoreJsonCodec { + InMemoryEmbeddingStore fromJson(String json); + + String toJson(InMemoryEmbeddingStore store); +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java index c1b56a32c..78ec3be6f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java @@ -1,21 +1,23 @@ package com.tencent.supersonic.common.util.embedding; import static dev.langchain4j.internal.Utils.randomUUID; +import static java.nio.file.StandardOpenOption.CREATE; +import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; import static java.util.Comparator.comparingDouble; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.spi.ServiceHelper; -import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory; import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.RelevanceScore; -import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec; -import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.List; @@ -40,7 +42,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { new ConcurrentHashMap<>(); @Override - public void addCollection(String collectionName) { + public synchronized void addCollection(String collectionName) { collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore()); } @@ -164,7 +166,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { if (o == null || getClass() != o.getClass()) { return false; } - Entry that = (Entry) o; + InMemoryEmbeddingStore.Entry that = (InMemoryEmbeddingStore.Entry) o; return Objects.equals(this.id, that.id) && Objects.equals(this.embedding, that.embedding) && Objects.equals(this.embedded, that.embedded); @@ -176,7 +178,8 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { } } - private final List> entries = new CopyOnWriteArrayList<>(); + private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec(); + private final List> entries = new CopyOnWriteArrayList<>(); @Override public String add(Embedding embedding) { @@ -198,7 +201,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { } public void add(String id, Embedding embedding, Embedded embedded) { - entries.add(new Entry<>(id, embedding, embedded)); + entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded)); } @Override @@ -230,7 +233,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { Comparator> comparator = comparingDouble(EmbeddingMatch::score); PriorityQueue> matches = new PriorityQueue<>(comparator); - for (Entry entry : entries) { + for (InMemoryEmbeddingStore.Entry entry : entries) { double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding); double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity); if (score >= minScore) { @@ -264,16 +267,44 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { return Objects.hash(entries); } - private static InMemoryEmbeddingStoreJsonCodec loadCodec() { - Collection factories = ServiceHelper.loadFactories( - InMemoryEmbeddingStoreJsonCodecFactory.class); - for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) { - return factory.create(); + public String serializeToJson() { + return CODEC.toJson(this); + } + + public void serializeToFile(Path filePath) { + try { + String json = serializeToJson(); + Files.write(filePath, json.getBytes(), CREATE, TRUNCATE_EXISTING); + } catch (IOException e) { + throw new RuntimeException(e); } + } + + public void serializeToFile(String filePath) { + serializeToFile(Paths.get(filePath)); + } + + private static InMemoryEmbeddingStoreJsonCodec loadCodec() { // fallback to default return new GsonInMemoryEmbeddingStoreJsonCodec(); } + public static InMemoryEmbeddingStore fromJson(String json) { + return CODEC.fromJson(json); + } + + public static InMemoryEmbeddingStore fromFile(Path filePath) { + try { + String json = new String(Files.readAllBytes(filePath)); + return fromJson(json); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static InMemoryEmbeddingStore fromFile(String filePath) { + return fromFile(Paths.get(filePath)); + } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/PythonServiceS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/PythonServiceS2EmbeddingStore.java index f109db651..a530b61fe 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/PythonServiceS2EmbeddingStore.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/PythonServiceS2EmbeddingStore.java @@ -4,12 +4,12 @@ import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.serializer.SerializerFeature; import com.google.common.collect.Lists; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.util.ContextUtils; import java.net.URI; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; @@ -26,9 +26,6 @@ import org.springframework.web.util.UriComponentsBuilder; @Slf4j public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore { - @Autowired - private EmbeddingConfig embeddingConfig; - private RestTemplate restTemplate = new RestTemplate(); public void addCollection(String collectionName) { @@ -36,6 +33,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore { if (collections.contains(collectionName)) { return; } + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); String url = String.format("%s/create_collection?collection_name=%s", embeddingConfig.getUrl(), collectionName); doRequest(url, null, HttpMethod.GET); @@ -45,6 +43,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore { if (CollectionUtils.isEmpty(queries)) { return; } + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); String url = String.format("%s/add_query?collection_name=%s", embeddingConfig.getUrl(), collectionName); doRequest(url, JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue), HttpMethod.POST); @@ -54,6 +53,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore { if (CollectionUtils.isEmpty(queries)) { return; } + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); List queryIds = queries.stream().map(EmbeddingQuery::getQueryId).collect(Collectors.toList()); String url = String.format("%s/delete_query_by_ids?collection_name=%s", embeddingConfig.getUrl(), collectionName); @@ -61,6 +61,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore { } public List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s", embeddingConfig.getUrl(), collectionName, num); ResponseEntity responseEntity = doRequest(url, JSONObject.toJSONString(retrieveQuery, @@ -72,6 +73,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore { } private List getCollectionList() { + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); String url = embeddingConfig.getUrl() + "/list_collections"; ResponseEntity responseEntity = doRequest(url, null, HttpMethod.GET); if (!responseEntity.hasBody()) {