mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement](chat) Add persistence to InMemoryEmbeddingStore, fix the issue of PythonServiceS2EmbeddingStore being empty. (#524)
This commit is contained in:
@@ -37,7 +37,7 @@ public class JavaLLMProxy implements LLMProxy {
|
|||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
|
||||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||||
SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode()));
|
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||||
String modelName = llmReq.getSchema().getModelName();
|
String modelName = llmReq.getSchema().getModelName();
|
||||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||||
|
|
||||||
|
|||||||
@@ -69,5 +69,14 @@ public class LLMReq {
|
|||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static SqlGenerationMode getMode(String name) {
|
||||||
|
for (SqlGenerationMode sqlGenerationMode : SqlGenerationMode.values()) {
|
||||||
|
if (sqlGenerationMode.name.equals(name)) {
|
||||||
|
return sqlGenerationMode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
|
import com.google.gson.Gson;
|
||||||
|
import com.google.gson.reflect.TypeToken;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import java.lang.reflect.Type;
|
||||||
|
|
||||||
|
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
|
||||||
|
Type type = new TypeToken<InMemoryEmbeddingStore<TextSegment>>() {
|
||||||
|
}.getType();
|
||||||
|
return new Gson().fromJson(json, type);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toJson(InMemoryEmbeddingStore<?> store) {
|
||||||
|
return new Gson().toJson(store);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
|
||||||
|
public interface InMemoryEmbeddingStoreJsonCodec {
|
||||||
|
InMemoryEmbeddingStore<TextSegment> fromJson(String json);
|
||||||
|
|
||||||
|
String toJson(InMemoryEmbeddingStore<?> store);
|
||||||
|
}
|
||||||
@@ -1,21 +1,23 @@
|
|||||||
package com.tencent.supersonic.common.util.embedding;
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||||
|
import static java.nio.file.StandardOpenOption.CREATE;
|
||||||
|
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
|
||||||
import static java.util.Comparator.comparingDouble;
|
import static java.util.Comparator.comparingDouble;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.spi.ServiceHelper;
|
|
||||||
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
|
|
||||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||||
import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec;
|
import java.io.IOException;
|
||||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec;
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -40,7 +42,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
new ConcurrentHashMap<>();
|
new ConcurrentHashMap<>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addCollection(String collectionName) {
|
public synchronized void addCollection(String collectionName) {
|
||||||
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
|
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,7 +166,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
if (o == null || getClass() != o.getClass()) {
|
if (o == null || getClass() != o.getClass()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
Entry<?> that = (Entry<?>) o;
|
InMemoryEmbeddingStore.Entry<?> that = (InMemoryEmbeddingStore.Entry<?>) o;
|
||||||
return Objects.equals(this.id, that.id)
|
return Objects.equals(this.id, that.id)
|
||||||
&& Objects.equals(this.embedding, that.embedding)
|
&& Objects.equals(this.embedding, that.embedding)
|
||||||
&& Objects.equals(this.embedded, that.embedded);
|
&& Objects.equals(this.embedded, that.embedded);
|
||||||
@@ -176,7 +178,8 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private final List<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
|
private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();
|
||||||
|
private final List<InMemoryEmbeddingStore.Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String add(Embedding embedding) {
|
public String add(Embedding embedding) {
|
||||||
@@ -198,7 +201,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void add(String id, Embedding embedding, Embedded embedded) {
|
public void add(String id, Embedding embedding, Embedded embedded) {
|
||||||
entries.add(new Entry<>(id, embedding, embedded));
|
entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -230,7 +233,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
|
Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
|
||||||
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
|
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
|
||||||
|
|
||||||
for (Entry<Embedded> entry : entries) {
|
for (InMemoryEmbeddingStore.Entry<Embedded> entry : entries) {
|
||||||
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
|
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
|
||||||
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
||||||
if (score >= minScore) {
|
if (score >= minScore) {
|
||||||
@@ -264,16 +267,44 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
return Objects.hash(entries);
|
return Objects.hash(entries);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
|
public String serializeToJson() {
|
||||||
Collection<InMemoryEmbeddingStoreJsonCodecFactory> factories = ServiceHelper.loadFactories(
|
return CODEC.toJson(this);
|
||||||
InMemoryEmbeddingStoreJsonCodecFactory.class);
|
}
|
||||||
for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) {
|
|
||||||
return factory.create();
|
public void serializeToFile(Path filePath) {
|
||||||
|
try {
|
||||||
|
String json = serializeToJson();
|
||||||
|
Files.write(filePath, json.getBytes(), CREATE, TRUNCATE_EXISTING);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void serializeToFile(String filePath) {
|
||||||
|
serializeToFile(Paths.get(filePath));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
|
||||||
// fallback to default
|
// fallback to default
|
||||||
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
|
||||||
|
return CODEC.fromJson(json);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static InMemoryEmbeddingStore<TextSegment> fromFile(Path filePath) {
|
||||||
|
try {
|
||||||
|
String json = new String(Files.readAllBytes(filePath));
|
||||||
|
return fromJson(json);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) {
|
||||||
|
return fromFile(Paths.get(filePath));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import com.alibaba.fastjson.JSONObject;
|
|||||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.core.ParameterizedTypeReference;
|
import org.springframework.core.ParameterizedTypeReference;
|
||||||
import org.springframework.http.HttpEntity;
|
import org.springframework.http.HttpEntity;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
@@ -26,9 +26,6 @@ import org.springframework.web.util.UriComponentsBuilder;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private EmbeddingConfig embeddingConfig;
|
|
||||||
|
|
||||||
private RestTemplate restTemplate = new RestTemplate();
|
private RestTemplate restTemplate = new RestTemplate();
|
||||||
|
|
||||||
public void addCollection(String collectionName) {
|
public void addCollection(String collectionName) {
|
||||||
@@ -36,6 +33,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
if (collections.contains(collectionName)) {
|
if (collections.contains(collectionName)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String url = String.format("%s/create_collection?collection_name=%s",
|
String url = String.format("%s/create_collection?collection_name=%s",
|
||||||
embeddingConfig.getUrl(), collectionName);
|
embeddingConfig.getUrl(), collectionName);
|
||||||
doRequest(url, null, HttpMethod.GET);
|
doRequest(url, null, HttpMethod.GET);
|
||||||
@@ -45,6 +43,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
if (CollectionUtils.isEmpty(queries)) {
|
if (CollectionUtils.isEmpty(queries)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String url = String.format("%s/add_query?collection_name=%s",
|
String url = String.format("%s/add_query?collection_name=%s",
|
||||||
embeddingConfig.getUrl(), collectionName);
|
embeddingConfig.getUrl(), collectionName);
|
||||||
doRequest(url, JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue), HttpMethod.POST);
|
doRequest(url, JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue), HttpMethod.POST);
|
||||||
@@ -54,6 +53,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
if (CollectionUtils.isEmpty(queries)) {
|
if (CollectionUtils.isEmpty(queries)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
List<String> queryIds = queries.stream().map(EmbeddingQuery::getQueryId).collect(Collectors.toList());
|
List<String> queryIds = queries.stream().map(EmbeddingQuery::getQueryId).collect(Collectors.toList());
|
||||||
String url = String.format("%s/delete_query_by_ids?collection_name=%s",
|
String url = String.format("%s/delete_query_by_ids?collection_name=%s",
|
||||||
embeddingConfig.getUrl(), collectionName);
|
embeddingConfig.getUrl(), collectionName);
|
||||||
@@ -61,6 +61,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||||
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s",
|
String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s",
|
||||||
embeddingConfig.getUrl(), collectionName, num);
|
embeddingConfig.getUrl(), collectionName, num);
|
||||||
ResponseEntity<String> responseEntity = doRequest(url, JSONObject.toJSONString(retrieveQuery,
|
ResponseEntity<String> responseEntity = doRequest(url, JSONObject.toJSONString(retrieveQuery,
|
||||||
@@ -72,6 +73,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private List<String> getCollectionList() {
|
private List<String> getCollectionList() {
|
||||||
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String url = embeddingConfig.getUrl() + "/list_collections";
|
String url = embeddingConfig.getUrl() + "/list_collections";
|
||||||
ResponseEntity<String> responseEntity = doRequest(url, null, HttpMethod.GET);
|
ResponseEntity<String> responseEntity = doRequest(url, null, HttpMethod.GET);
|
||||||
if (!responseEntity.hasBody()) {
|
if (!responseEntity.hasBody()) {
|
||||||
|
|||||||
Reference in New Issue
Block a user