mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
(improvement)(common) Optimized the structure of the common package. (#1164)
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public class ComponentFactory {
|
||||
|
||||
private static S2EmbeddingStore s2EmbeddingStore;
|
||||
|
||||
public static S2EmbeddingStore getS2EmbeddingStore() {
|
||||
if (Objects.isNull(s2EmbeddingStore)) {
|
||||
s2EmbeddingStore = init(S2EmbeddingStore.class);
|
||||
}
|
||||
return s2EmbeddingStore;
|
||||
}
|
||||
|
||||
private static <T> T init(Class<T> factoryType) {
|
||||
return SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()).get(0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class EmbeddingCollection {
|
||||
|
||||
private String id;
|
||||
|
||||
private String name;
|
||||
|
||||
private Map<String, String> metaData;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class EmbeddingQuery {
|
||||
|
||||
private String queryId;
|
||||
|
||||
private String query;
|
||||
|
||||
private Map<String, Object> metadata;
|
||||
|
||||
private List<Double> queryEmbedding;
|
||||
|
||||
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
|
||||
return dataItems.stream().map(dataItem -> {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(
|
||||
dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||
embeddingQuery.setQuery(dataItem.getName());
|
||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||
embeddingQuery.setMetadata(meta);
|
||||
embeddingQuery.setQueryEmbedding(null);
|
||||
return embeddingQuery;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
|
||||
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {
|
||||
|
||||
@Override
|
||||
public InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
|
||||
Type type = new TypeToken<InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<EmbeddingQuery>>() {
|
||||
}.getType();
|
||||
return new Gson().fromJson(json, type);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toJson(InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<?> store) {
|
||||
return new Gson().toJson(store);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import dev.langchain4j.store.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
|
||||
|
||||
public interface InMemoryEmbeddingStoreJsonCodec {
|
||||
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json);
|
||||
|
||||
String toJson(InMemoryEmbeddingStore<?> store);
|
||||
}
|
||||
@@ -0,0 +1,351 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArraySet;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.nio.file.StandardOpenOption.CREATE;
|
||||
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
|
||||
import static java.util.Comparator.comparingDouble;
|
||||
|
||||
/***
|
||||
* Implementation of S2EmbeddingStore within the Java process's in-memory.
|
||||
*/
|
||||
@Slf4j
|
||||
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||
|
||||
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
||||
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public synchronized void addCollection(String collectionName) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null;
|
||||
Path filePath = getPersistentPath(collectionName);
|
||||
try {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName())
|
||||
&& !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
|
||||
embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
|
||||
embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries);
|
||||
log.info("embeddingStore reload from file:{}", filePath);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("load persistentFile error, persistentFile:" + filePath, e);
|
||||
}
|
||||
if (Objects.isNull(embeddingStore)) {
|
||||
embeddingStore = new InMemoryEmbeddingStore();
|
||||
}
|
||||
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
||||
}
|
||||
|
||||
private Path getPersistentPath(String collectionName) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String persistentFile = PERSISTENT_FILE_PRE + collectionName;
|
||||
return Paths.get(embeddingConfig.getEmbeddingStorePersistentPath(), persistentFile);
|
||||
}
|
||||
|
||||
public void persistentToFile() {
|
||||
for (Entry<String, InMemoryEmbeddingStore<EmbeddingQuery>> entry : collectionNameToStore.entrySet()) {
|
||||
Path filePath = getPersistentPath(entry.getKey());
|
||||
try {
|
||||
Path directoryPath = filePath.getParent();
|
||||
if (!Files.exists(directoryPath)) {
|
||||
Files.createDirectories(directoryPath);
|
||||
}
|
||||
entry.getValue().serializeToFile(filePath);
|
||||
} catch (Exception e) {
|
||||
log.error("persistentToFile error, persistentFile:" + filePath, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
|
||||
EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
|
||||
for (EmbeddingQuery query : queries) {
|
||||
String question = query.getQuery();
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
embeddingStore.add(query.getQueryId(), embedding, query);
|
||||
}
|
||||
}
|
||||
|
||||
private InMemoryEmbeddingStore<EmbeddingQuery> getEmbeddingStore(String collectionName) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = collectionNameToStore.get(collectionName);
|
||||
if (Objects.isNull(embeddingStore)) {
|
||||
synchronized (InMemoryS2EmbeddingStore.class) {
|
||||
addCollection(collectionName);
|
||||
embeddingStore = collectionNameToStore.get(collectionName);
|
||||
}
|
||||
}
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
//not support in InMemoryEmbeddingStore
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
|
||||
EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
|
||||
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
for (String queryText : queryTextsList) {
|
||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||
int maxResults = getMaxResults(num, filterCondition);
|
||||
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, maxResults);
|
||||
|
||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||
retrieveQueryResult.setQuery(queryText);
|
||||
List<Retrieval> retrievals = new ArrayList<>();
|
||||
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
|
||||
Retrieval retrieval = new Retrieval();
|
||||
retrieval.setDistance(1 - embeddingMatch.score());
|
||||
retrieval.setId(embeddingMatch.embeddingId());
|
||||
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
if (Objects.nonNull(embeddingMatch.embedded())
|
||||
&& MapUtils.isNotEmpty(embeddingMatch.embedded().getMetadata())) {
|
||||
metadata.putAll(embeddingMatch.embedded().getMetadata());
|
||||
}
|
||||
if (filterRetrieval(filterCondition, metadata)) {
|
||||
continue;
|
||||
}
|
||||
retrieval.setMetadata(metadata);
|
||||
retrievals.add(retrieval);
|
||||
}
|
||||
retrievals = retrievals.stream()
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
||||
.limit(num)
|
||||
.collect(Collectors.toList());
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
results.add(retrieveQueryResult);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
private int getMaxResults(int num, Map<String, String> filterCondition) {
|
||||
int maxResults = num;
|
||||
if (MapUtils.isNotEmpty(filterCondition)) {
|
||||
maxResults = num * 5;
|
||||
}
|
||||
return maxResults;
|
||||
}
|
||||
|
||||
private boolean filterRetrieval(Map<String, String> filterCondition, Map<String, Object> metadata) {
|
||||
if (MapUtils.isNotEmpty(metadata) && MapUtils.isNotEmpty(filterCondition)) {
|
||||
for (Entry<String, Object> entry : metadata.entrySet()) {
|
||||
String filterValue = filterCondition.get(entry.getKey());
|
||||
if (StringUtils.isNotBlank(filterValue) && !filterValue.equalsIgnoreCase(
|
||||
entry.getValue().toString())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* An {@link EmbeddingStore} that stores embeddings in memory.
|
||||
* <p>
|
||||
* Uses a brute force approach by iterating over all embeddings to find the best matches.
|
||||
*
|
||||
* @param <Embedded> The class of the object that has been embedded.
|
||||
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
|
||||
* copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore
|
||||
* and fix concurrentModificationException in a multi-threaded environment
|
||||
*/
|
||||
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
|
||||
|
||||
private static class Entry<Embedded> {
|
||||
|
||||
String id;
|
||||
Embedding embedding;
|
||||
Embedded embedded;
|
||||
|
||||
Entry(String id, Embedding embedding, Embedded embedded) {
|
||||
this.id = id;
|
||||
this.embedding = embedding;
|
||||
this.embedded = embedded;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
InMemoryEmbeddingStore.Entry<?> that = (InMemoryEmbeddingStore.Entry<?>) o;
|
||||
return Objects.equals(this.id, that.id)
|
||||
&& Objects.equals(this.embedding, that.embedding)
|
||||
&& Objects.equals(this.embedded, that.embedded);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(id, embedding, embedded);
|
||||
}
|
||||
}
|
||||
|
||||
private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();
|
||||
private Set<Entry<Embedded>> entries = new CopyOnWriteArraySet<>();
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding) {
|
||||
String id = randomUUID();
|
||||
add(id, embedding);
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(String id, Embedding embedding) {
|
||||
add(id, embedding, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String add(Embedding embedding, Embedded embedded) {
|
||||
String id = randomUUID();
|
||||
add(id, embedding, embedded);
|
||||
return id;
|
||||
}
|
||||
|
||||
public void add(String id, Embedding embedding, Embedded embedded) {
|
||||
entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
List<String> ids = new ArrayList<>();
|
||||
for (Embedding embedding : embeddings) {
|
||||
ids.add(add(embedding));
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
|
||||
if (embeddings.size() != embedded.size()) {
|
||||
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
|
||||
}
|
||||
|
||||
List<String> ids = new ArrayList<>();
|
||||
for (int i = 0; i < embeddings.size(); i++) {
|
||||
ids.add(add(embeddings.get(i), embedded.get(i)));
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding, int maxResults,
|
||||
double minScore) {
|
||||
|
||||
Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
|
||||
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
|
||||
|
||||
for (InMemoryEmbeddingStore.Entry<Embedded> entry : entries) {
|
||||
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
|
||||
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
||||
if (score >= minScore) {
|
||||
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
|
||||
if (matches.size() > maxResults) {
|
||||
matches.poll();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
|
||||
result.sort(comparator);
|
||||
Collections.reverse(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
InMemoryEmbeddingStore<?> that = (InMemoryEmbeddingStore<?>) o;
|
||||
return Objects.equals(this.entries, that.entries);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(entries);
|
||||
}
|
||||
|
||||
public String serializeToJson() {
|
||||
return CODEC.toJson(this);
|
||||
}
|
||||
|
||||
public void serializeToFile(Path filePath) {
|
||||
try {
|
||||
String json = serializeToJson();
|
||||
Files.write(filePath, json.getBytes(), CREATE, TRUNCATE_EXISTING);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public void serializeToFile(String filePath) {
|
||||
serializeToFile(Paths.get(filePath));
|
||||
}
|
||||
|
||||
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
|
||||
// fallback to default
|
||||
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
||||
}
|
||||
|
||||
public static InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
|
||||
return CODEC.fromJson(json);
|
||||
}
|
||||
|
||||
public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(Path filePath) {
|
||||
try {
|
||||
String json = new String(Files.readAllBytes(filePath));
|
||||
return fromJson(json);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(String filePath) {
|
||||
return fromFile(Paths.get(filePath));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* Implementation of calling the Python service S2EmbeddingStore.
|
||||
*/
|
||||
@Slf4j
|
||||
public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
|
||||
|
||||
private RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
public void addCollection(String collectionName) {
|
||||
List<String> collections = getCollectionList();
|
||||
if (collections.contains(collectionName)) {
|
||||
return;
|
||||
}
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = String.format("%s/create_collection?collection_name=%s",
|
||||
embeddingConfig.getUrl(), collectionName);
|
||||
doRequest(url, null, HttpMethod.GET);
|
||||
}
|
||||
|
||||
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
return;
|
||||
}
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = String.format("%s/add_query?collection_name=%s",
|
||||
embeddingConfig.getUrl(), collectionName);
|
||||
doRequest(url, JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue), HttpMethod.POST);
|
||||
}
|
||||
|
||||
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
return;
|
||||
}
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
List<String> queryIds = queries.stream().map(EmbeddingQuery::getQueryId).collect(Collectors.toList());
|
||||
String url = String.format("%s/delete_query_by_ids?collection_name=%s",
|
||||
embeddingConfig.getUrl(), collectionName);
|
||||
doRequest(url, JSONObject.toJSONString(queryIds), HttpMethod.POST);
|
||||
}
|
||||
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s",
|
||||
embeddingConfig.getUrl(), collectionName, num);
|
||||
ResponseEntity<String> responseEntity = doRequest(url, JSONObject.toJSONString(retrieveQuery,
|
||||
SerializerFeature.WriteMapNullValue), HttpMethod.POST);
|
||||
if (!responseEntity.hasBody()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return JSONObject.parseArray(responseEntity.getBody(), RetrieveQueryResult.class);
|
||||
}
|
||||
|
||||
private List<String> getCollectionList() {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String url = embeddingConfig.getUrl() + "/list_collections";
|
||||
ResponseEntity<String> responseEntity = doRequest(url, null, HttpMethod.GET);
|
||||
if (!responseEntity.hasBody()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<EmbeddingCollection> embeddingCollections = JSONObject.parseArray(responseEntity.getBody(),
|
||||
EmbeddingCollection.class);
|
||||
return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) {
|
||||
try {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setLocation(URI.create(url));
|
||||
URI requestUrl = UriComponentsBuilder
|
||||
.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(headers);
|
||||
if (jsonBody != null) {
|
||||
log.info("[embedding] request body :{}", jsonBody);
|
||||
entity = new HttpEntity<>(jsonBody, headers);
|
||||
}
|
||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||
httpMethod, entity, new ParameterizedTypeReference<String>() {
|
||||
});
|
||||
log.info("[embedding] url :{} result body:{}", url, responseEntity);
|
||||
return responseEntity;
|
||||
} catch (Throwable e) {
|
||||
log.warn("doRequest service failed, url:" + url, e);
|
||||
}
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class Retrieval {
|
||||
|
||||
protected String id;
|
||||
|
||||
protected double distance;
|
||||
|
||||
protected String query;
|
||||
|
||||
protected Map<String, Object> metadata;
|
||||
|
||||
public static Long getLongId(Object id) {
|
||||
if (id == null || StringUtils.isBlank(id.toString())) {
|
||||
return null;
|
||||
}
|
||||
String[] split = id.toString().split(Constants.UNDERLINE);
|
||||
return Long.parseLong(split[0]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
Retrieval retrieval = (Retrieval) o;
|
||||
return Double.compare(retrieval.distance, distance) == 0 && Objects.equal(id,
|
||||
retrieval.id) && Objects.equal(query, retrieval.query)
|
||||
&& Objects.equal(metadata, retrieval.metadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(id, distance, query, metadata);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class RetrieveQuery {
|
||||
|
||||
private List<String> queryTextsList;
|
||||
|
||||
private Map<String, String> filterCondition;
|
||||
|
||||
private List<List<Double>> queryEmbeddings;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class RetrieveQueryResult {
|
||||
|
||||
private String query;
|
||||
|
||||
private List<Retrieval> retrieval;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Supersonic EmbeddingStore
|
||||
* Enhanced the functionality by enabling the addition and querying of collection names.
|
||||
*/
|
||||
public interface S2EmbeddingStore {
|
||||
|
||||
void addCollection(String collectionName);
|
||||
|
||||
void addQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||
|
||||
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||
|
||||
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user