(improvement)(common) Optimized the structure of the common package. (#1164)

This commit is contained in:
lexluo09
2024-06-19 18:07:26 +08:00
committed by GitHub
parent 5d32235c2d
commit 48113b41dd
104 changed files with 202 additions and 199 deletions

View File

@@ -0,0 +1,23 @@
package dev.langchain4j.store.embedding;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.Objects;
/**
 * Lazily resolves and caches the {@link S2EmbeddingStore} implementation that is
 * registered through Spring's factories-loading mechanism.
 */
public class ComponentFactory {

    /** Cached store; populated on the first call to {@link #getS2EmbeddingStore()}. */
    private static S2EmbeddingStore s2EmbeddingStore;

    /**
     * Returns the shared {@link S2EmbeddingStore}, loading it on first use.
     * <p>
     * The method is synchronized so that concurrent first callers cannot each load
     * (and then use) their own instance.
     *
     * @return the cached store instance
     * @throws IllegalStateException if no implementation is registered
     */
    public static synchronized S2EmbeddingStore getS2EmbeddingStore() {
        if (Objects.isNull(s2EmbeddingStore)) {
            s2EmbeddingStore = init(S2EmbeddingStore.class);
        }
        return s2EmbeddingStore;
    }

    /**
     * Loads every registered implementation of {@code factoryType} and returns the first one.
     *
     * @param factoryType the interface whose implementation is wanted
     * @param <T>         the factory type
     * @return the first implementation returned by the loader
     * @throws IllegalStateException if the loader returns no implementation
     */
    private static <T> T init(Class<T> factoryType) {
        java.util.List<T> factories = SpringFactoriesLoader.loadFactories(factoryType,
                Thread.currentThread().getContextClassLoader());
        if (factories.isEmpty()) {
            // Fail with a message that names the missing type instead of a bare index error.
            throw new IllegalStateException(
                    "No implementation registered for factory type: " + factoryType.getName());
        }
        return factories.get(0);
    }
}

View File

@@ -0,0 +1,17 @@
package dev.langchain4j.store.embedding;
import lombok.Data;
import java.util.Map;
/**
 * Describes one embedding collection by id, name and free-form string metadata.
 * Accessors, equals/hashCode and toString are generated by Lombok's {@code @Data}.
 */
@Data
public class EmbeddingCollection {
// Identifier of the collection.
private String id;
// Name of the collection.
private String name;
// Additional string key/value attributes attached to the collection.
private Map<String, String> metaData;
}

View File

@@ -0,0 +1,36 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import lombok.Data;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * A text entry to be embedded, together with its id, metadata and (optionally) a
 * pre-computed embedding vector.
 */
@Data
public class EmbeddingQuery {

    /** Identifier of the entry inside a collection. */
    private String queryId;

    /** The text that gets embedded. */
    private String query;

    /** Arbitrary attributes stored alongside the text. */
    private Map<String, Object> metadata;

    /** Embedding vector of {@link #query}; may be {@code null}. */
    private List<Double> queryEmbedding;

    /**
     * Converts data items into embedding queries.
     * <p>
     * For every item the id is the item id followed by the lower-cased type name, the
     * query text is the item name, the metadata is the item's JSON representation as a
     * map, and the embedding is left {@code null}.
     *
     * @param dataItems the items to convert
     * @return one query per item, in the same order
     */
    public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
        return dataItems.stream()
                .map(EmbeddingQuery::fromDataItem)
                .collect(Collectors.toList());
    }

    /** Builds the embedding query for a single data item. */
    @SuppressWarnings("unchecked") // fastjson returns a raw Map for Map.class
    private static EmbeddingQuery fromDataItem(DataItem dataItem) {
        EmbeddingQuery embeddingQuery = new EmbeddingQuery();
        // Locale.ROOT keeps the generated id identical on every machine; the default-locale
        // variant would, for example, turn 'I' into a dotless 'ı' under a Turkish locale.
        embeddingQuery.setQueryId(
                dataItem.getId() + dataItem.getType().name().toLowerCase(java.util.Locale.ROOT));
        embeddingQuery.setQuery(dataItem.getName());
        Map<String, Object> meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
        embeddingQuery.setMetadata(meta);
        embeddingQuery.setQueryEmbedding(null);
        return embeddingQuery;
    }
}

View File

@@ -0,0 +1,21 @@
package dev.langchain4j.store.embedding;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;
/**
 * {@link InMemoryEmbeddingStoreJsonCodec} backed by Gson with its default configuration.
 */
public class GsonInMemoryEmbeddingStoreJsonCodec implements InMemoryEmbeddingStoreJsonCodec {

    /** Default-configured Gson shared by both directions. */
    private static final Gson GSON = new Gson();

    /** Full generic type of the store, needed so Gson can rebuild the embedded objects. */
    private static final Type STORE_TYPE =
            new TypeToken<InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<EmbeddingQuery>>() {
            }.getType();

    /**
     * Rebuilds a store from its JSON form.
     *
     * @param json the JSON text to parse
     * @return the store parsed from {@code json}
     */
    @Override
    public InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
        return GSON.fromJson(json, STORE_TYPE);
    }

    /**
     * Renders a store as JSON.
     *
     * @param store the store to serialize
     * @return the JSON text
     */
    @Override
    public String toJson(InMemoryS2EmbeddingStore.InMemoryEmbeddingStore<?> store) {
        return GSON.toJson(store);
    }
}

View File

@@ -0,0 +1,10 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.store.embedding.InMemoryS2EmbeddingStore.InMemoryEmbeddingStore;
/**
 * Converts an {@link InMemoryEmbeddingStore} to and from its JSON text form.
 */
public interface InMemoryEmbeddingStoreJsonCodec {
/**
 * Parses a store holding {@link EmbeddingQuery} objects from JSON.
 *
 * @param json the JSON text
 * @return the parsed store
 */
InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json);
/**
 * Serializes the given store to JSON.
 *
 * @param store the store to serialize
 * @return the JSON text
 */
String toJson(InMemoryEmbeddingStore<?> store);
}

View File

@@ -0,0 +1,351 @@
package dev.langchain4j.store.embedding;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
import static java.util.Comparator.comparingDouble;
/***
* In-memory implementation of S2EmbeddingStore that keeps all embeddings inside the Java process.
*/
@Slf4j
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
// Prefix of the persisted file name; the collection name is appended to it.
public static final String PERSISTENT_FILE_PRE = "InMemory.";
// Stores keyed by collection name. The map is static, so it is shared by every instance.
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
new ConcurrentHashMap<>();
/**
 * Registers a collection if it is not registered yet.
 * <p>
 * When a persisted file exists for the collection, and the collection is neither the
 * configured meta collection nor the text2sql collection, the store is restored from that
 * file; otherwise, or when loading fails, an empty store is registered.
 *
 * @param collectionName name of the collection to register
 */
@Override
public synchronized void addCollection(String collectionName) {
    if (collectionNameToStore.containsKey(collectionName)) {
        // Already registered: anything loaded now would be thrown away by putIfAbsent,
        // so skip the file read (and its misleading "reload" log line).
        return;
    }
    InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null;
    Path filePath = getPersistentPath(collectionName);
    try {
        EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
        if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName())
                && !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
            embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
            // Re-wrap the deserialized entries in the same concurrent set type a fresh store uses.
            embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries);
            log.info("embeddingStore reload from file:{}", filePath);
        }
    } catch (Exception e) {
        // Best effort: a broken file must not prevent the collection from being created.
        log.error("load persistentFile error, persistentFile:" + filePath, e);
        embeddingStore = null;
    }
    if (Objects.isNull(embeddingStore)) {
        embeddingStore = new InMemoryEmbeddingStore<>();
    }
    collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
}
/**
 * Resolves the file a collection is persisted to: the configured persistent directory
 * plus {@link #PERSISTENT_FILE_PRE} followed by the collection name.
 *
 * @param collectionName name of the collection
 * @return path of the collection's persistent file
 */
private Path getPersistentPath(String collectionName) {
    String baseDirectory = ContextUtils.getBean(EmbeddingConfig.class).getEmbeddingStorePersistentPath();
    return Paths.get(baseDirectory, PERSISTENT_FILE_PRE + collectionName);
}
/**
 * Writes every registered collection to its persistent file, creating the target
 * directory when needed. A failure for one collection is logged and does not stop the
 * remaining collections from being written.
 */
public void persistentToFile() {
    for (Entry<String, InMemoryEmbeddingStore<EmbeddingQuery>> entry : collectionNameToStore.entrySet()) {
        Path filePath = getPersistentPath(entry.getKey());
        try {
            Path directoryPath = filePath.getParent();
            // getParent() is null for a single-element relative path (e.g. an empty
            // configured directory); the file then lives in the working directory and
            // there is nothing to create.
            if (directoryPath != null && !Files.exists(directoryPath)) {
                Files.createDirectories(directoryPath);
            }
            entry.getValue().serializeToFile(filePath);
        } catch (Exception e) {
            log.error("persistentToFile error, persistentFile:" + filePath, e);
        }
    }
}
/**
 * Embeds the text of every query with the configured {@link EmbeddingModel} and adds it,
 * under the query's id, to the store of the given collection.
 *
 * @param collectionName target collection; registered on demand
 * @param queries        the queries to embed and store
 */
@Override
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
    InMemoryEmbeddingStore<EmbeddingQuery> store = getEmbeddingStore(collectionName);
    EmbeddingModel model = ContextUtils.getBean(EmbeddingModel.class);
    queries.forEach(item -> {
        Embedding vector = model.embed(item.getQuery()).content();
        store.add(item.getQueryId(), vector, item);
    });
}
/**
 * Looks up the store of a collection, registering the collection first when it is unknown.
 *
 * @param collectionName name of the collection
 * @return the store registered under {@code collectionName}
 */
private InMemoryEmbeddingStore<EmbeddingQuery> getEmbeddingStore(String collectionName) {
    InMemoryEmbeddingStore<EmbeddingQuery> existing = collectionNameToStore.get(collectionName);
    if (Objects.nonNull(existing)) {
        return existing;
    }
    // Unknown collection: register it under the class lock, then read it back.
    synchronized (InMemoryS2EmbeddingStore.class) {
        addCollection(collectionName);
        return collectionNameToStore.get(collectionName);
    }
}
/**
 * Intentionally does nothing: deleting queries is not supported by the in-memory store,
 * so both arguments are ignored and previously added entries stay in the collection.
 */
@Override
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
//not support in InMemoryEmbeddingStore
}
/**
 * Finds, for every query text, the stored entries that are most similar to it.
 * <p>
 * Each text is embedded with the configured {@link EmbeddingModel} and matched against the
 * collection. Matches whose metadata contradicts the filter condition are dropped. The
 * distance of a match is {@code 1 - score}, so a smaller distance means a closer match;
 * results are ordered closest first and capped at {@code num}.
 *
 * @param collectionName collection to search; registered on demand
 * @param retrieveQuery  the query texts and the optional metadata filter
 * @param num            maximum number of retrievals returned per query text
 * @return one result per query text, in the order of the query texts
 */
@Override
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
    InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
    EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
    List<RetrieveQueryResult> results = new ArrayList<>();
    List<String> queryTextsList = retrieveQuery.getQueryTextsList();
    Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
    // Depends only on num and the filter, so compute it once for all query texts.
    int maxResults = getMaxResults(num, filterCondition);
    for (String queryText : queryTextsList) {
        Embedding embeddedText = embeddingModel.embed(queryText).content();
        List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, maxResults);
        RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
        retrieveQueryResult.setQuery(queryText);
        List<Retrieval> retrievals = new ArrayList<>();
        for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
            // Entries may have been stored without an embedded object; read it once and
            // guard every access instead of dereferencing it before the null check.
            EmbeddingQuery embedded = embeddingMatch.embedded();
            Map<String, Object> metadata = new HashMap<>();
            if (Objects.nonNull(embedded) && MapUtils.isNotEmpty(embedded.getMetadata())) {
                metadata.putAll(embedded.getMetadata());
            }
            if (filterRetrieval(filterCondition, metadata)) {
                continue;
            }
            Retrieval retrieval = new Retrieval();
            retrieval.setDistance(1 - embeddingMatch.score());
            retrieval.setId(embeddingMatch.embeddingId());
            retrieval.setQuery(Objects.nonNull(embedded) ? embedded.getQuery() : null);
            retrieval.setMetadata(metadata);
            retrievals.add(retrieval);
        }
        // Ascending distance = closest first, so the cap keeps the best matches. (Sorting
        // descending here would keep the worst ones whenever the filter widened the search.)
        retrievals = retrievals.stream()
                .sorted(Comparator.comparingDouble(Retrieval::getDistance))
                .limit(num)
                .collect(Collectors.toList());
        retrieveQueryResult.setRetrieval(retrievals);
        results.add(retrieveQueryResult);
    }
    return results;
}
/**
 * Decides how many matches to fetch from the store: five times {@code num} when a filter
 * condition is present, otherwise exactly {@code num}.
 *
 * @param num             number of results requested
 * @param filterCondition metadata filter; may be null or empty
 * @return number of matches to request from the store
 */
private int getMaxResults(int num, Map<String, String> filterCondition) {
    return MapUtils.isNotEmpty(filterCondition) ? num * 5 : num;
}
/**
 * Tells whether a match must be dropped because its metadata contradicts the filter.
 * <p>
 * Only metadata keys that have a non-blank filter value are checked; the comparison is
 * case-insensitive on the value's string form. Keys named in the filter but absent from
 * the metadata do not cause the match to be dropped.
 *
 * @param filterCondition metadata key to required value; may be null or empty
 * @param metadata        metadata of the candidate match; may be null or empty
 * @return {@code true} if the match should be filtered out, {@code false} to keep it
 */
private boolean filterRetrieval(Map<String, String> filterCondition, Map<String, Object> metadata) {
    if (MapUtils.isEmpty(metadata) || MapUtils.isEmpty(filterCondition)) {
        return false;
    }
    for (Entry<String, Object> entry : metadata.entrySet()) {
        String filterValue = filterCondition.get(entry.getKey());
        if (StringUtils.isBlank(filterValue)) {
            continue;
        }
        Object actualValue = entry.getValue();
        // A null metadata value cannot equal a non-blank filter value: treat it as a
        // mismatch instead of failing with a NullPointerException.
        if (actualValue == null || !filterValue.equalsIgnoreCase(actualValue.toString())) {
            return true;
        }
    }
    return false;
}
/**
 * An {@link EmbeddingStore} that stores embeddings in memory.
 * <p>
 * Uses a brute force approach by iterating over all embeddings to find the best matches.
 * <p>
 * Copied from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore, with the
 * entries held in a {@link CopyOnWriteArraySet} to fix a ConcurrentModificationException
 * in a multi-threaded environment.
 *
 * @param <Embedded> The class of the object that has been embedded.
 *                   The enclosing class uses it with {@link EmbeddingQuery}.
 */
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {

    /** One stored embedding together with its id and the object it was computed from. */
    private static class Entry<Embedded> {
        String id;
        Embedding embedding;
        Embedded embedded;

        Entry(String id, Embedding embedding, Embedded embedded) {
            this.id = id;
            this.embedding = embedding;
            this.embedded = embedded;
        }

        /** Two entries are equal when id, embedding and embedded object are all equal. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            InMemoryEmbeddingStore.Entry<?> that = (InMemoryEmbeddingStore.Entry<?>) o;
            return Objects.equals(this.id, that.id)
                    && Objects.equals(this.embedding, that.embedding)
                    && Objects.equals(this.embedded, that.embedded);
        }

        @Override
        public int hashCode() {
            return Objects.hash(id, embedding, embedded);
        }
    }

    /** Codec used for every JSON conversion of this class. */
    private static final InMemoryEmbeddingStoreJsonCodec CODEC = loadCodec();

    /** All stored entries; a set, so adding an entry equal to an existing one is a no-op. */
    private Set<Entry<Embedded>> entries = new CopyOnWriteArraySet<>();

    /**
     * Adds an embedding under a newly generated id.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        String id = randomUUID();
        add(id, embedding);
        return id;
    }

    /**
     * Adds an embedding under the given id, without an embedded object.
     *
     * @param id        id of the entry
     * @param embedding the embedding to store
     */
    @Override
    public void add(String id, Embedding embedding) {
        add(id, embedding, null);
    }

    /**
     * Adds an embedding and its embedded object under a newly generated id.
     *
     * @param embedding the embedding to store
     * @param embedded  the object the embedding was computed from
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, Embedded embedded) {
        String id = randomUUID();
        add(id, embedding, embedded);
        return id;
    }

    /**
     * Adds an embedding and its embedded object under the given id.
     *
     * @param id        id of the entry
     * @param embedding the embedding to store
     * @param embedded  the object the embedding was computed from; may be {@code null}
     */
    public void add(String id, Embedding embedding, Embedded embedded) {
        entries.add(new InMemoryEmbeddingStore.Entry<>(id, embedding, embedded));
    }

    /**
     * Adds several embeddings, each under a newly generated id.
     *
     * @param embeddings the embeddings to store
     * @return the generated ids, in the order of {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = new ArrayList<>();
        for (Embedding embedding : embeddings) {
            ids.add(add(embedding));
        }
        return ids;
    }

    /**
     * Adds several embeddings with their embedded objects, pairing them by index.
     *
     * @param embeddings the embeddings to store
     * @param embedded   the embedded objects, one per embedding
     * @return the generated ids, in the order of {@code embeddings}
     * @throws IllegalArgumentException if the two lists differ in size
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
        if (embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        List<String> ids = new ArrayList<>();
        for (int i = 0; i < embeddings.size(); i++) {
            ids.add(add(embeddings.get(i), embedded.get(i)));
        }
        return ids;
    }

    /**
     * Scans all entries and returns those most similar to the reference embedding.
     *
     * @param referenceEmbedding the embedding to compare against
     * @param maxResults         maximum number of matches to return
     * @param minScore           matches scoring below this value are skipped
     * @return the matches, highest score first
     */
    @Override
    public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding, int maxResults,
            double minScore) {
        Comparator<EmbeddingMatch<Embedded>> comparator = comparingDouble(EmbeddingMatch::score);
        // Min-heap on score: once it holds more than maxResults, evict the weakest match.
        PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
        for (InMemoryEmbeddingStore.Entry<Embedded> entry : entries) {
            double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
            double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
            if (score >= minScore) {
                matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
                if (matches.size() > maxResults) {
                    matches.poll();
                }
            }
        }
        List<EmbeddingMatch<Embedded>> result = new ArrayList<>(matches);
        result.sort(comparator);
        Collections.reverse(result);
        return result;
    }

    /** Two stores are equal when they hold equal sets of entries. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        InMemoryEmbeddingStore<?> that = (InMemoryEmbeddingStore<?>) o;
        return Objects.equals(this.entries, that.entries);
    }

    @Override
    public int hashCode() {
        return Objects.hash(entries);
    }

    /**
     * Serializes this store to JSON.
     *
     * @return the JSON text
     */
    public String serializeToJson() {
        return CODEC.toJson(this);
    }

    /**
     * Writes this store as UTF-8 encoded JSON, replacing any existing file content.
     *
     * @param filePath file to write
     * @throws RuntimeException wrapping the {@link IOException} if writing fails
     */
    public void serializeToFile(Path filePath) {
        try {
            String json = serializeToJson();
            // Fixed charset: the platform default differs between machines and JVM versions,
            // which would garble non-ASCII text when a file is read back elsewhere.
            Files.write(filePath, json.getBytes(java.nio.charset.StandardCharsets.UTF_8),
                    CREATE, TRUNCATE_EXISTING);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes this store as JSON to the file at the given path string.
     *
     * @param filePath path of the file to write
     */
    public void serializeToFile(String filePath) {
        serializeToFile(Paths.get(filePath));
    }

    /** Returns the codec to use; currently always the Gson-based one. */
    private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
        // fallback to default
        return new GsonInMemoryEmbeddingStoreJsonCodec();
    }

    /**
     * Rebuilds a store from JSON.
     *
     * @param json the JSON text
     * @return the parsed store
     */
    public static InMemoryEmbeddingStore<EmbeddingQuery> fromJson(String json) {
        return CODEC.fromJson(json);
    }

    /**
     * Rebuilds a store from a UTF-8 encoded JSON file.
     *
     * @param filePath file to read
     * @return the parsed store
     * @throws RuntimeException wrapping the {@link IOException} if reading fails
     */
    public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(Path filePath) {
        try {
            // Must match the charset used by serializeToFile.
            String json = new String(Files.readAllBytes(filePath), java.nio.charset.StandardCharsets.UTF_8);
            return fromJson(json);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Rebuilds a store from the JSON file at the given path string.
     *
     * @param filePath path of the file to read
     * @return the parsed store
     */
    public static InMemoryEmbeddingStore<EmbeddingQuery> fromFile(String filePath) {
        return fromFile(Paths.get(filePath));
    }
}
}

View File

@@ -0,0 +1,111 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/***
* Implementation of calling the Python service S2EmbeddingStore.
*/
@Slf4j
public class PythonServiceS2EmbeddingStore implements S2EmbeddingStore {
private RestTemplate restTemplate = new RestTemplate();
public void addCollection(String collectionName) {
List<String> collections = getCollectionList();
if (collections.contains(collectionName)) {
return;
}
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = String.format("%s/create_collection?collection_name=%s",
embeddingConfig.getUrl(), collectionName);
doRequest(url, null, HttpMethod.GET);
}
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
if (CollectionUtils.isEmpty(queries)) {
return;
}
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = String.format("%s/add_query?collection_name=%s",
embeddingConfig.getUrl(), collectionName);
doRequest(url, JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue), HttpMethod.POST);
}
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
if (CollectionUtils.isEmpty(queries)) {
return;
}
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
List<String> queryIds = queries.stream().map(EmbeddingQuery::getQueryId).collect(Collectors.toList());
String url = String.format("%s/delete_query_by_ids?collection_name=%s",
embeddingConfig.getUrl(), collectionName);
doRequest(url, JSONObject.toJSONString(queryIds), HttpMethod.POST);
}
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s",
embeddingConfig.getUrl(), collectionName, num);
ResponseEntity<String> responseEntity = doRequest(url, JSONObject.toJSONString(retrieveQuery,
SerializerFeature.WriteMapNullValue), HttpMethod.POST);
if (!responseEntity.hasBody()) {
return Lists.newArrayList();
}
return JSONObject.parseArray(responseEntity.getBody(), RetrieveQueryResult.class);
}
private List<String> getCollectionList() {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String url = embeddingConfig.getUrl() + "/list_collections";
ResponseEntity<String> responseEntity = doRequest(url, null, HttpMethod.GET);
if (!responseEntity.hasBody()) {
return Lists.newArrayList();
}
List<EmbeddingCollection> embeddingCollections = JSONObject.parseArray(responseEntity.getBody(),
EmbeddingCollection.class);
return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList());
}
public ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) {
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(headers);
if (jsonBody != null) {
log.info("[embedding] request body :{}", jsonBody);
entity = new HttpEntity<>(jsonBody, headers);
}
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
httpMethod, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] url :{} result body:{}", url, responseEntity);
return responseEntity;
} catch (Throwable e) {
log.warn("doRequest service failed, url:" + url, e);
}
return ResponseEntity.of(Optional.empty());
}
}

View File

@@ -0,0 +1,47 @@
package dev.langchain4j.store.embedding;
import com.google.common.base.Objects;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import java.util.Map;
/**
 * A single retrieved entry: its id, its distance to the query, the stored text and the
 * stored metadata.
 */
@Data
public class Retrieval {

    protected String id;

    protected double distance;

    protected String query;

    protected Map<String, Object> metadata;

    /**
     * Extracts the numeric part of an id, i.e. the text before the first
     * {@link Constants#UNDERLINE} separator, parsed as a long.
     *
     * @param id the id object; its string form is used
     * @return the parsed number, or {@code null} when {@code id} is null or blank
     * @throws NumberFormatException if the leading part is not a valid long
     */
    public static Long getLongId(Object id) {
        if (id == null) {
            return null;
        }
        String text = id.toString();
        if (StringUtils.isBlank(text)) {
            return null;
        }
        String leadingPart = text.split(Constants.UNDERLINE)[0];
        return Long.parseLong(leadingPart);
    }

    /** Two retrievals are equal when distance, id, query and metadata are all equal. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Retrieval other = (Retrieval) o;
        if (Double.compare(other.distance, distance) != 0) {
            return false;
        }
        if (!Objects.equal(id, other.id)) {
            return false;
        }
        if (!Objects.equal(query, other.query)) {
            return false;
        }
        return Objects.equal(metadata, other.metadata);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(id, distance, query, metadata);
    }
}

View File

@@ -0,0 +1,20 @@
package dev.langchain4j.store.embedding;
import lombok.Builder;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
 * Parameters of a retrieval request; instances are created through the Lombok builder.
 */
@Data
@Builder
public class RetrieveQuery {
// Texts to search for; each text produces its own result.
private List<String> queryTextsList;
// Metadata key to required value, used to filter matches.
private Map<String, String> filterCondition;
// Presumably pre-computed embeddings, one vector per query text — TODO confirm with the consumers.
private List<List<Double>> queryEmbeddings;
}

View File

@@ -0,0 +1,15 @@
package dev.langchain4j.store.embedding;
import lombok.Data;
import java.util.List;
/**
 * Result of a retrieval for one query text.
 */
@Data
public class RetrieveQueryResult {
// The query text this result belongs to.
private String query;
// The entries retrieved for the query text.
private List<Retrieval> retrieval;
}

View File

@@ -0,0 +1,19 @@
package dev.langchain4j.store.embedding;
import java.util.List;
/**
 * Supersonic EmbeddingStore.
 * Extends a plain embedding store with named collections: every operation
 * addresses a collection by name.
 */
public interface S2EmbeddingStore {
/**
 * Makes sure a collection with the given name exists.
 *
 * @param collectionName name of the collection
 */
void addCollection(String collectionName);
/**
 * Adds queries to a collection.
 *
 * @param collectionName name of the target collection
 * @param queries        the queries to add
 */
void addQuery(String collectionName, List<EmbeddingQuery> queries);
/**
 * Deletes queries from a collection. Implementations may not support deletion.
 *
 * @param collectionName name of the target collection
 * @param queries        the queries to delete
 */
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
/**
 * Retrieves the stored entries that best match the given query.
 *
 * @param collectionName name of the collection to search
 * @param retrieveQuery  query texts and optional filter condition
 * @param num            number of results requested per query text
 * @return the retrieval results
 */
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
}