[improvement](python) LLM related services support Java service invocation (#484)

This commit is contained in:
lexluo09
2023-12-08 19:24:58 +08:00
committed by GitHub
parent 6c0f88d8b5
commit abbe8c84a1
33 changed files with 1037 additions and 95 deletions

View File

@@ -29,6 +29,7 @@ public class EmbeddingConfig {
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
private String metricAnalyzeQueryCollection;
@Value("${embedding.metric.analyzeQuery.nResult:5}")
private int metricAnalyzeQueryResultNum;
}

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.common.util;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.util.Objects;
import org.springframework.core.io.support.SpringFactoriesLoader;
/**
 * Static access point for components that are discovered through Spring's
 * {@code spring.factories} mechanism ({@link SpringFactoriesLoader}).
 *
 * <p>The resolved component is cached in a static field after the first lookup.
 * The lazy initialisation is not synchronised, so concurrent first calls may each
 * run the lookup; the last assignment wins.
 */
public class ComponentFactory {

    /** Cached store; remains {@code null} until {@link #getS2EmbeddingStore()} is first called. */
    private static S2EmbeddingStore s2EmbeddingStore;

    /**
     * Returns the {@link S2EmbeddingStore} implementation, loading it on first use.
     *
     * @return the cached store instance
     * @throws IllegalStateException if no implementation is registered
     */
    public static S2EmbeddingStore getS2EmbeddingStore() {
        if (Objects.isNull(s2EmbeddingStore)) {
            s2EmbeddingStore = init(S2EmbeddingStore.class);
        }
        return s2EmbeddingStore;
    }

    /**
     * Loads the implementations registered for {@code factoryType} and returns the first one.
     *
     * @param factoryType the interface or abstract class whose implementation is wanted
     * @param <T> the component type
     * @return the first implementation returned by {@link SpringFactoriesLoader#loadFactories}
     * @throws IllegalStateException if nothing is registered for {@code factoryType}
     */
    private static <T> T init(Class<T> factoryType) {
        // Fail with an actionable message instead of a bare IndexOutOfBoundsException
        // when the spring.factories entry is missing.
        return SpringFactoriesLoader.loadFactories(factoryType,
                        Thread.currentThread().getContextClassLoader())
                .stream()
                .findFirst()
                .orElseThrow(() -> new IllegalStateException(
                        "No implementation of " + factoryType.getName()
                                + " is registered in META-INF/spring.factories"));
    }
}

View File

@@ -13,7 +13,7 @@ public class EmbeddingQuery {
private String query;
private Map<String, String> metadata;
private Map<String, Object> metadata;
private List<Double> queryEmbedding;

View File

@@ -0,0 +1,83 @@
package com.tencent.supersonic.common.util.embedding;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;
/**
 * {@link S2EmbeddingStore} implementation that keeps each collection in its own
 * langchain4j {@link InMemoryEmbeddingStore}.
 *
 * <p>The collections are held in a static map, so they are shared by every instance
 * of this class within the JVM.
 */
@Slf4j
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {

    /** Collection name to backing store; entries are created on demand. */
    private static final Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> COLLECTION_NAME_TO_STORE =
            new ConcurrentHashMap<>();

    /**
     * Creates the named collection if it does not exist yet; an existing collection is left untouched.
     *
     * @param collectionName name of the collection to create
     */
    @Override
    public void addCollection(String collectionName) {
        getEmbeddingStore(collectionName);
    }

    /**
     * Embeds the text of each query and stores it, keyed by the query id, in the named collection.
     * The collection is created if it does not exist. A {@code null} or empty list adds nothing.
     *
     * @param collectionName target collection
     * @param queries queries to embed and store
     */
    @Override
    public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
        InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
        if (Objects.isNull(queries) || queries.isEmpty()) {
            return;
        }
        EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
        for (EmbeddingQuery query : queries) {
            Embedding embedding = embeddingModel.embed(query.getQuery()).content();
            embeddingStore.add(query.getQueryId(), embedding, query);
        }
    }

    /**
     * Returns the store for the named collection, creating it atomically when absent.
     *
     * @param collectionName collection to look up
     * @return the existing or newly created store
     */
    private InMemoryEmbeddingStore<EmbeddingQuery> getEmbeddingStore(String collectionName) {
        // ConcurrentHashMap.computeIfAbsent is atomic per key, so no extra locking is needed.
        return COLLECTION_NAME_TO_STORE.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore<>());
    }

    /**
     * Does nothing: deletion is not supported by this in-memory implementation.
     *
     * @param collectionName ignored
     * @param queries ignored
     */
    @Override
    public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
        // not supported in InMemoryEmbeddingStore
    }

    /**
     * Embeds every query text and looks up the most relevant stored entries for each.
     *
     * @param collectionName collection to search; created empty if it does not exist
     * @param retrieveQuery carrier of the query texts
     * @param num maximum number of matches requested per query text
     * @return one result per query text, in input order; empty when there are no query texts
     */
    @Override
    public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
        InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
        List<RetrieveQueryResult> results = new ArrayList<>();
        if (Objects.isNull(retrieveQuery) || Objects.isNull(retrieveQuery.getQueryTextsList())) {
            return results;
        }
        EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
        for (String queryText : retrieveQuery.getQueryTextsList()) {
            Embedding embeddedText = embeddingModel.embed(queryText).content();
            List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, num);
            List<Retrieval> retrievals = new ArrayList<>();
            for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
                EmbeddingQuery embedded = embeddingMatch.embedded();
                Retrieval retrieval = new Retrieval();
                // NOTE(review): score() appears to be a relevance score (higher = closer), while the
                // target field is named "distance" - confirm consumers expect this orientation.
                retrieval.setDistance(embeddingMatch.score());
                retrieval.setId(embeddingMatch.embeddingId());
                retrieval.setQuery(embedded.getQuery());
                retrieval.setMetadata(embedded.getMetadata());
                retrievals.add(retrieval);
            }
            RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
            retrieveQueryResult.setQuery(queryText);
            retrieveQueryResult.setRetrieval(retrievals);
            results.add(retrieveQueryResult);
        }
        return results;
    }
}

View File

@@ -4,6 +4,10 @@ import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.ParameterizedTypeReference;
@@ -12,18 +16,12 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
@Slf4j
@Component
public class EmbeddingUtils {
public class PythonS2EmbeddingStore implements S2EmbeddingStore {
@Autowired
private EmbeddingConfig embeddingConfig;
@@ -103,6 +101,5 @@ public class EmbeddingUtils {
}
return ResponseEntity.of(Optional.empty());
}
}

View File

@@ -15,7 +15,7 @@ public class Retrieval {
protected String query;
protected Map<String, String> metadata;
protected Map<String, Object> metadata;
public static Long getLongId(String id) {
if (StringUtils.isBlank(id)) {

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.common.util.embedding;
import java.util.List;
/**
 * Supersonic embedding store.
 *
 * <p>Every operation takes a collection name, so one store can hold several
 * independent collections of embedded queries.
 */
public interface S2EmbeddingStore {
/**
 * Adds a collection with the given name.
 *
 * @param collectionName name of the collection to add
 */
void addCollection(String collectionName);
/**
 * Adds the given queries to the named collection.
 *
 * @param collectionName collection that receives the queries
 * @param queries queries to add
 */
void addQuery(String collectionName, List<EmbeddingQuery> queries);
/**
 * Deletes the given queries from the named collection.
 * NOTE(review): an implementation may not support deletion and treat this as a no-op - confirm per implementation.
 *
 * @param collectionName collection to delete from
 * @param queries queries to delete
 */
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
/**
 * Retrieves stored queries from the named collection for the given retrieve request.
 *
 * @param collectionName collection to search
 * @param retrieveQuery the retrieve request
 * @param num number of results requested (presumably per query text - verify against implementations)
 * @return the retrieve results
 */
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
}