mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) Support integrating embeddingStore into system settings. (#1405)
This commit is contained in:
@@ -15,7 +15,7 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("EmbeddingModelConfig")
|
||||
@Service("EmbeddingModelParameterConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.pojo;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
@@ -11,6 +12,7 @@ import java.io.Serializable;
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode
|
||||
public class EmbeddingStoreConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private String provider;
|
||||
|
||||
@@ -14,6 +14,7 @@ import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
@@ -40,10 +41,6 @@ import java.util.stream.Collectors;
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingModelParameterConfig embeddingModelParameterConfig;
|
||||
|
||||
@@ -54,6 +51,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
|
||||
for (TextSegment query : queries) {
|
||||
@@ -102,6 +100,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
@Override
|
||||
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
||||
//Not supported yet in Milvus and Chroma
|
||||
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
try {
|
||||
if (embeddingStore instanceof InMemoryEmbeddingStore) {
|
||||
@@ -122,7 +121,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
@Override
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
|
||||
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
|
||||
@@ -16,6 +16,6 @@ public class ChromaAutoConfig {
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
|
||||
EmbeddingStoreFactory chromaChatModel(Properties properties) {
|
||||
return new ChromaEmbeddingStoreFactory(properties);
|
||||
return new ChromaEmbeddingStoreFactory(properties.getEmbeddingStore());
|
||||
}
|
||||
}
|
||||
@@ -12,19 +12,18 @@ import java.time.Duration;
|
||||
@Slf4j
|
||||
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
|
||||
private Properties properties;
|
||||
private EmbeddingStoreProperties storeProperties;
|
||||
|
||||
public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||
this(createPropertiesFromConfig(storeConfig));
|
||||
}
|
||||
|
||||
public ChromaEmbeddingStoreFactory(Properties properties) {
|
||||
this.properties = properties;
|
||||
public ChromaEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) {
|
||||
this.storeProperties = storeProperties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
||||
return ChromaEmbeddingStore.builder()
|
||||
.baseUrl(storeProperties.getBaseUrl())
|
||||
.collectionName(collectionName)
|
||||
@@ -32,12 +31,10 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
.build();
|
||||
}
|
||||
|
||||
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||
Properties properties = new Properties();
|
||||
private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
|
||||
properties.setEmbeddingStore(embeddingStore);
|
||||
return properties;
|
||||
return embeddingStore;
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,13 @@ package dev.langchain4j.chroma.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@ToString
|
||||
public class EmbeddingStoreProperties {
|
||||
|
||||
private String baseUrl;
|
||||
|
||||
@@ -6,7 +6,7 @@ import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
protected static final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
|
||||
protected final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
|
||||
|
||||
public EmbeddingStore<TextSegment> create(String collectionName) {
|
||||
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
||||
@@ -7,20 +8,33 @@ import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class EmbeddingStoreFactoryProvider {
|
||||
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) {
|
||||
if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) {
|
||||
protected static final Map<EmbeddingStoreConfig, EmbeddingStoreFactory> factoryMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static EmbeddingStoreFactory getFactory() {
|
||||
EmbeddingStoreParameterConfig parameterConfig = ContextUtils.getBean(EmbeddingStoreParameterConfig.class);
|
||||
return getFactory(parameterConfig.convert());
|
||||
}
|
||||
|
||||
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig embeddingStoreConfig) {
|
||||
if (embeddingStoreConfig == null || StringUtils.isBlank(embeddingStoreConfig.getProvider())) {
|
||||
return ContextUtils.getBean(EmbeddingStoreFactory.class);
|
||||
}
|
||||
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
||||
return new ChromaEmbeddingStoreFactory(storeConfig);
|
||||
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
||||
return new MilvusEmbeddingStoreFactory(storeConfig);
|
||||
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
||||
return new InMemoryEmbeddingStoreFactory(storeConfig);
|
||||
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider());
|
||||
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: " + embeddingStoreConfig.getProvider());
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.server.web.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.web.service.MetricService;
|
||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
|
||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -35,15 +36,13 @@ public class MetaEmbeddingTask implements CommandLineRunner {
|
||||
@Autowired
|
||||
private DimensionService dimensionService;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||
|
||||
@PreDestroy
|
||||
public void onShutdown() {
|
||||
embeddingStorePersistFile();
|
||||
}
|
||||
|
||||
private void embeddingStorePersistFile() {
|
||||
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||
if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
InMemoryEmbeddingStoreFactory inMemoryFactory =
|
||||
|
||||
Reference in New Issue
Block a user