mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
[improvement](python) LLM related services support Java service invocation (#484)
This commit is contained in:
@@ -29,6 +29,7 @@ public class EmbeddingConfig {
|
||||
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
|
||||
private String metricAnalyzeQueryCollection;
|
||||
|
||||
|
||||
@Value("${embedding.metric.analyzeQuery.nResult:5}")
|
||||
private int metricAnalyzeQueryResultNum;
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import java.util.Objects;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
public class ComponentFactory {
|
||||
|
||||
private static S2EmbeddingStore s2EmbeddingStore;
|
||||
|
||||
public static S2EmbeddingStore getS2EmbeddingStore() {
|
||||
if (Objects.isNull(s2EmbeddingStore)) {
|
||||
s2EmbeddingStore = init(S2EmbeddingStore.class);
|
||||
}
|
||||
return s2EmbeddingStore;
|
||||
}
|
||||
|
||||
private static <T> T init(Class<T> factoryType) {
|
||||
return SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()).get(0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -13,7 +13,7 @@ public class EmbeddingQuery {
|
||||
|
||||
private String query;
|
||||
|
||||
private Map<String, String> metadata;
|
||||
private Map<String, Object> metadata;
|
||||
|
||||
private List<Double> queryEmbedding;
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
package com.tencent.supersonic.common.util.embedding;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
|
||||
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public void addCollection(String collectionName) {
|
||||
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
|
||||
EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
|
||||
for (EmbeddingQuery query : queries) {
|
||||
String question = query.getQuery();
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
embeddingStore.add(query.getQueryId(), embedding, query);
|
||||
}
|
||||
}
|
||||
|
||||
private InMemoryEmbeddingStore<EmbeddingQuery> getEmbeddingStore(String collectionName) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = collectionNameToStore.get(collectionName);
|
||||
if (Objects.isNull(embeddingStore)) {
|
||||
synchronized (InMemoryS2EmbeddingStore.class) {
|
||||
addCollection(collectionName);
|
||||
embeddingStore = collectionNameToStore.get(collectionName);
|
||||
}
|
||||
|
||||
}
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
//not support in InMemoryEmbeddingStore
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
|
||||
EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
|
||||
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
for (String queryText : queryTextsList) {
|
||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, num);
|
||||
|
||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||
retrieveQueryResult.setQuery(queryText);
|
||||
List<Retrieval> retrievals = new ArrayList<>();
|
||||
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
|
||||
Retrieval retrieval = new Retrieval();
|
||||
retrieval.setDistance(embeddingMatch.score());
|
||||
retrieval.setId(embeddingMatch.embeddingId());
|
||||
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
||||
retrieval.setMetadata(embeddingMatch.embedded().getMetadata());
|
||||
retrievals.add(retrieval);
|
||||
}
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
results.add(retrieveQueryResult);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,10 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
@@ -12,18 +16,12 @@ import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class EmbeddingUtils {
|
||||
public class PythonS2EmbeddingStore implements S2EmbeddingStore {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
@@ -103,6 +101,5 @@ public class EmbeddingUtils {
|
||||
}
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ public class Retrieval {
|
||||
|
||||
protected String query;
|
||||
|
||||
protected Map<String, String> metadata;
|
||||
protected Map<String, Object> metadata;
|
||||
|
||||
public static Long getLongId(String id) {
|
||||
if (StringUtils.isBlank(id)) {
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.common.util.embedding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Supersonic EmbeddingStore
|
||||
* Added the functionality of adding and querying collection names.
|
||||
*/
|
||||
public interface S2EmbeddingStore {
|
||||
|
||||
void addCollection(String collectionName);
|
||||
|
||||
void addQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||
|
||||
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||
|
||||
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user