(improvement)(chat) Upgrade and optimize the embedding metastore. (#1198)

This commit is contained in:
lexluo09
2024-06-23 21:46:10 +08:00
committed by GitHub
parent 2ae94fb38c
commit 4d6cbf31f7
46 changed files with 3788 additions and 498 deletions

View File

@@ -2,12 +2,11 @@ package com.tencent.supersonic.headless.chat.knowledge;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Constants;
import dev.langchain4j.store.embedding.ComponentFactory;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
@@ -26,7 +25,8 @@ import org.springframework.stereotype.Service;
@Slf4j
public class MetaEmbeddingService {
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Autowired
private EmbeddingService embeddingService;
@Autowired
private EmbeddingConfig embeddingConfig;
@@ -42,7 +42,7 @@ public class MetaEmbeddingService {
}
String collectionName = embeddingConfig.getMetaCollectionName();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery, num);
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(collectionName, retrieveQuery, num);
if (CollectionUtils.isEmpty(resultList)) {
return new ArrayList<>();
}

View File

@@ -3,19 +3,12 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import dev.langchain4j.store.embedding.ComponentFactory;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.S2EmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
@@ -24,6 +17,11 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
@Slf4j
@Component
@@ -31,7 +29,8 @@ public class ExemplarManager {
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Autowired
private EmbeddingService embeddingService;
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
};
@@ -56,7 +55,7 @@ public class ExemplarManager {
embeddingQuery.setMetadata(metaDataMap);
queries.add(embeddingQuery);
}
s2EmbeddingStore.addQuery(collectionName, queries);
embeddingService.addQuery(collectionName, queries);
}
public List<Map<String, String>> recallExemplars(String queryText, int maxResults) {
@@ -64,7 +63,7 @@ public class ExemplarManager {
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
.queryEmbeddings(null).build();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery,
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(collectionName, retrieveQuery,
maxResults);
List<Map<String, String>> result = new ArrayList<>();
if (CollectionUtils.isEmpty(resultList)) {