(improvement)(chat) Support loading local embedding models through the in-memory configuration method (#1201)

This commit is contained in:
lexluo09
2024-06-23 22:57:35 +08:00
committed by GitHub
parent 15ceca3102
commit a7d367baa3
25 changed files with 56 additions and 3464 deletions

View File

@@ -1,9 +1,7 @@
package com.tencent.supersonic.common.service.impl;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingQuery;
@@ -24,7 +22,6 @@ import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -34,6 +31,9 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Autowired
private EmbeddingStoreFactory embeddingStoreFactory;
@Autowired
private EmbeddingModel embeddingModel;
public synchronized void addCollection(String collectionName) {
embeddingStoreFactory.create(collectionName);
}
@@ -41,7 +41,6 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
EmbeddingModel embeddingModel = getEmbeddingModel();
for (EmbeddingQuery query : queries) {
String question = query.getQuery();
Embedding embedding = embeddingModel.embed(question).content();
@@ -49,26 +48,15 @@ public class EmbeddingServiceImpl implements EmbeddingService {
}
}
private static EmbeddingModel getEmbeddingModel() {
EmbeddingModel embeddingModel;
try {
embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
} catch (NoSuchBeanDefinitionException e) {
embeddingModel = new BgeSmallZhEmbeddingModel();
}
return embeddingModel;
}
@Override
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
}
@Override
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
EmbeddingModel embeddingModel = getEmbeddingModel();
List<RetrieveQueryResult> results = new ArrayList<>();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) {
@@ -110,7 +98,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
private static Filter createCombinedFilter(Map<String, String> map) {
Filter result = null;
if (Objects.isNull(map)) {
if (MapUtils.isEmpty(map)) {
return null;
}
for (Map.Entry<String, String> entry : map.entrySet()) {