mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-19 00:37:08 +00:00
(improvement)(chat) Support loading local embedding models through the in-memory configuration method (#1201)
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
package com.tencent.supersonic.common.service.impl;
|
||||
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
@@ -24,7 +22,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -34,6 +31,9 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
@Autowired
|
||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingModel embeddingModel;
|
||||
|
||||
public synchronized void addCollection(String collectionName) {
|
||||
embeddingStoreFactory.create(collectionName);
|
||||
}
|
||||
@@ -41,7 +41,6 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel();
|
||||
for (EmbeddingQuery query : queries) {
|
||||
String question = query.getQuery();
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
@@ -49,26 +48,15 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
}
|
||||
}
|
||||
|
||||
private static EmbeddingModel getEmbeddingModel() {
|
||||
EmbeddingModel embeddingModel;
|
||||
try {
|
||||
embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
|
||||
} catch (NoSuchBeanDefinitionException e) {
|
||||
embeddingModel = new BgeSmallZhEmbeddingModel();
|
||||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel();
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
for (String queryText : queryTextsList) {
|
||||
@@ -110,7 +98,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
private static Filter createCombinedFilter(Map<String, String> map) {
|
||||
Filter result = null;
|
||||
if (Objects.isNull(map)) {
|
||||
if (MapUtils.isEmpty(map)) {
|
||||
return null;
|
||||
}
|
||||
for (Map.Entry<String, String> entry : map.entrySet()) {
|
||||
|
||||
Reference in New Issue
Block a user