mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:18:23 +00:00
[improvement](chat) Add persistence to InMemoryEmbeddingStore, fix the issue of PythonServiceS2EmbeddingStore being empty. (#524)
This commit is contained in:
@@ -37,7 +37,7 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode()));
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
|
||||
|
||||
@@ -69,5 +69,14 @@ public class LLMReq {
|
||||
return name;
|
||||
}
|
||||
|
||||
public static SqlGenerationMode getMode(String name) {
|
||||
for (SqlGenerationMode sqlGenerationMode : SqlGenerationMode.values()) {
|
||||
if (sqlGenerationMode.name.equals(name)) {
|
||||
return sqlGenerationMode;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.tencent.supersonic.common.util.embedding;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import java.lang.reflect.Type;
|
||||
|
||||
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {
|
||||
|
||||
@Override
|
||||
public InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
|
||||
Type type = new TypeToken<InMemoryEmbeddingStore<TextSegment>>() {
|
||||
}.getType();
|
||||
return new Gson().fromJson(json, type);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toJson(InMemoryEmbeddingStore<?> store) {
|
||||
return new Gson().toJson(store);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.common.util.embedding;
|
||||
|
||||
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
|
||||
public interface InMemoryEmbeddingStoreJsonCodec {
|
||||
InMemoryEmbeddingStore<TextSegment> fromJson(String json);
|
||||
|
||||
String toJson(InMemoryEmbeddingStore<?> store);
|
||||
}
|
||||
@@ -1,21 +1,23 @@
|
||||
package com.tencent.supersonic.common.util.embedding;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.nio.file.StandardOpenOption.CREATE;
|
||||
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
|
||||
import static java.util.Comparator.comparingDouble;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.spi.ServiceHelper;
|
||||
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import dev.langchain4j.store.embedding.inmemory.GsonInMemoryEmbeddingStoreJsonCodec;
|
||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodec;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
@@ -40,7 +42,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public void addCollection(String collectionName) {
|
||||
public synchronized void addCollection(String collectionName) {
|
||||
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
|
||||
}
|
||||
|
||||
@@ -164,7 +166,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
Entry<?> that = (Entry<?>) o;
|
||||
InMemoryEmbeddingStore.Entry<?> that = (InMemoryEmbeddingStore.Entry<?>) o;
|
||||
return Objects.equals(this.id, that.id)
|
||||
&& Objects.equals(this.embedding, that.embedding)
|
||||
&& Objects.equals(this.embedded, that.embedded);
|
||||
@@ -176,7 +178,8 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
}
|
||||
}
|
||||
|
||||
private final List<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
|
||||
private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();
|
||||
private final List<InMemoryEmbeddingStore.Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
@@ -198,7 +201,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
}
|
||||
|
||||
public void add(String id, Embedding embedding, Embedded embedded) {
|
||||
entries.add(new Entry<>(id, embedding, embedded));
|
||||
entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded));
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -230,7 +233,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
|
||||
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
|
||||
|
||||
for (Entry<Embedded> entry : entries) {
|
||||
for (InMemoryEmbeddingStore.Entry<Embedded> entry : entries) {
|
||||
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
|
||||
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
||||
if (score >= minScore) {
|
||||
@@ -264,16 +267,44 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
return Objects.hash(entries);
|
||||
}
|
||||
|
||||
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
|
||||
Collection<InMemoryEmbeddingStoreJsonCodecFactory> factories = ServiceHelper.loadFactories(
|
||||
InMemoryEmbeddingStoreJsonCodecFactory.class);
|
||||
for (InMemoryEmbeddingStoreJsonCodecFactory factory : factories) {
|
||||
return factory.create();
|
||||
public String serializeToJson() {
|
||||
return CODEC.toJson(this);
|
||||
}
|
||||
|
||||
public void serializeToFile(Path filePath) {
|
||||
try {
|
||||
String json = serializeToJson();
|
||||
Files.write(filePath, json.getBytes(), CREATE, TRUNCATE_EXISTING);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public void serializeToFile(String filePath) {
|
||||
serializeToFile(Paths.get(filePath));
|
||||
}
|
||||
|
||||
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
|
||||
// fallback to default
|
||||
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
||||
}
|
||||
|
||||
public static InMemoryEmbeddingStore<TextSegment> fromJson(String json) {
|
||||
return CODEC.fromJson(json);
|
||||
}
|
||||
|
||||
public static InMemoryEmbeddingStore<TextSegment> fromFile(Path filePath) {
|
||||
try {
|
||||
String json = new String(Files.readAllBytes(filePath));
|
||||
return fromJson(json);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) {
|
||||
return fromFile(Paths.get(filePath));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,12 +4,12 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -26,9 +26,6 @@ import org.springframework.web.util.UriComponentsBuilder;
|
||||
@Slf4j
|
||||
public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
public void addCollection(String collectionName) {
|
||||
@@ -36,6 +33,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||
if (collections.contains(collectionName)) {
|
||||
return;
|
||||
}
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = String.format("%s/create_collection?collection_name=%s",
|
||||
embeddingConfig.getUrl(), collectionName);
|
||||
doRequest(url, null, HttpMethod.GET);
|
||||
@@ -45,6 +43,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
return;
|
||||
}
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = String.format("%s/add_query?collection_name=%s",
|
||||
embeddingConfig.getUrl(), collectionName);
|
||||
doRequest(url, JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue), HttpMethod.POST);
|
||||
@@ -54,6 +53,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
return;
|
||||
}
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
List<String> queryIds = queries.stream().map(EmbeddingQuery::getQueryId).collect(Collectors.toList());
|
||||
String url = String.format("%s/delete_query_by_ids?collection_name=%s",
|
||||
embeddingConfig.getUrl(), collectionName);
|
||||
@@ -61,6 +61,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||
}
|
||||
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s",
|
||||
embeddingConfig.getUrl(), collectionName, num);
|
||||
ResponseEntity<String> responseEntity = doRequest(url, JSONObject.toJSONString(retrieveQuery,
|
||||
@@ -72,6 +73,7 @@ public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||
}
|
||||
|
||||
private List<String> getCollectionList() {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = embeddingConfig.getUrl() + "/list_collections";
|
||||
ResponseEntity<String> responseEntity = doRequest(url, null, HttpMethod.GET);
|
||||
if (!responseEntity.hasBody()) {
|
||||
|
||||
Reference in New Issue
Block a user