(improvement)(chat) Support integrating embeddingStore into system settings. (#1405)

This commit is contained in:
lexluo09
2024-07-14 21:50:12 +08:00
committed by GitHub
parent 4eb6193699
commit d5c78d87e7
9 changed files with 41 additions and 28 deletions

View File

@@ -15,7 +15,7 @@ import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
@Service("EmbeddingModelConfig") @Service("EmbeddingModelParameterConfig")
@Slf4j @Slf4j
public class EmbeddingModelParameterConfig extends ParameterConfig { public class EmbeddingModelParameterConfig extends ParameterConfig {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.pojo;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.io.Serializable; import java.io.Serializable;
@@ -11,6 +12,7 @@ import java.io.Serializable;
@Builder @Builder
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode
public class EmbeddingStoreConfig implements Serializable { public class EmbeddingStoreConfig implements Serializable {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private String provider; private String provider;

View File

@@ -14,6 +14,7 @@ import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
@@ -40,10 +41,6 @@ import java.util.stream.Collectors;
@Service @Service
@Slf4j @Slf4j
public class EmbeddingServiceImpl implements EmbeddingService { public class EmbeddingServiceImpl implements EmbeddingService {
@Autowired
private EmbeddingStoreFactory embeddingStoreFactory;
@Autowired @Autowired
private EmbeddingModelParameterConfig embeddingModelParameterConfig; private EmbeddingModelParameterConfig embeddingModelParameterConfig;
@@ -54,6 +51,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override @Override
public void addQuery(String collectionName, List<TextSegment> queries) { public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
for (TextSegment query : queries) { for (TextSegment query : queries) {
@@ -102,6 +100,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override @Override
public void deleteQuery(String collectionName, List<TextSegment> queries) { public void deleteQuery(String collectionName, List<TextSegment> queries) {
//Not supported yet in Milvus and Chroma //Not supported yet in Milvus and Chroma
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
try { try {
if (embeddingStore instanceof InMemoryEmbeddingStore) { if (embeddingStore instanceof InMemoryEmbeddingStore) {
@@ -122,7 +121,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override @Override
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
List<RetrieveQueryResult> results = new ArrayList<>(); List<RetrieveQueryResult> results = new ArrayList<>();
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
List<String> queryTextsList = retrieveQuery.getQueryTextsList(); List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition(); Map<String, String> filterCondition = retrieveQuery.getFilterCondition();

View File

@@ -16,6 +16,6 @@ public class ChromaAutoConfig {
@Bean @Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url") @ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
EmbeddingStoreFactory chromaChatModel(Properties properties) { EmbeddingStoreFactory chromaChatModel(Properties properties) {
return new ChromaEmbeddingStoreFactory(properties); return new ChromaEmbeddingStoreFactory(properties.getEmbeddingStore());
} }
} }

View File

@@ -12,19 +12,18 @@ import java.time.Duration;
@Slf4j @Slf4j
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private Properties properties; private EmbeddingStoreProperties storeProperties;
public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig)); this(createPropertiesFromConfig(storeConfig));
} }
public ChromaEmbeddingStoreFactory(Properties properties) { public ChromaEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) {
this.properties = properties; this.storeProperties = storeProperties;
} }
@Override @Override
public EmbeddingStore createEmbeddingStore(String collectionName) { public EmbeddingStore createEmbeddingStore(String collectionName) {
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
return ChromaEmbeddingStore.builder() return ChromaEmbeddingStore.builder()
.baseUrl(storeProperties.getBaseUrl()) .baseUrl(storeProperties.getBaseUrl())
.collectionName(collectionName) .collectionName(collectionName)
@@ -32,12 +31,10 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
.build(); .build();
} }
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore); BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut())); embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
properties.setEmbeddingStore(embeddingStore); return embeddingStore;
return properties;
} }
} }

View File

@@ -2,11 +2,13 @@ package dev.langchain4j.chroma.spring;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.ToString;
import java.time.Duration; import java.time.Duration;
@Getter @Getter
@Setter @Setter
@ToString
public class EmbeddingStoreProperties { public class EmbeddingStoreProperties {
private String baseUrl; private String baseUrl;

View File

@@ -6,7 +6,7 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory { public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
protected static final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>(); protected final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
public EmbeddingStore<TextSegment> create(String collectionName) { public EmbeddingStore<TextSegment> create(String collectionName) {
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore); return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);

View File

@@ -1,5 +1,6 @@
package dev.langchain4j.store.embedding; package dev.langchain4j.store.embedding;
import com.tencent.supersonic.common.config.EmbeddingStoreParameterConfig;
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory; import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
@@ -7,20 +8,33 @@ import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory; import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class EmbeddingStoreFactoryProvider { public class EmbeddingStoreFactoryProvider {
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) { protected static final Map<EmbeddingStoreConfig, EmbeddingStoreFactory> factoryMap = new ConcurrentHashMap<>();
if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) {
public static EmbeddingStoreFactory getFactory() {
EmbeddingStoreParameterConfig parameterConfig = ContextUtils.getBean(EmbeddingStoreParameterConfig.class);
return getFactory(parameterConfig.convert());
}
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig embeddingStoreConfig) {
if (embeddingStoreConfig == null || StringUtils.isBlank(embeddingStoreConfig.getProvider())) {
return ContextUtils.getBean(EmbeddingStoreFactory.class); return ContextUtils.getBean(EmbeddingStoreFactory.class);
} }
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) { if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return new ChromaEmbeddingStoreFactory(storeConfig); return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig));
} }
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) { if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return new MilvusEmbeddingStoreFactory(storeConfig); return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
} }
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) { if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return new InMemoryEmbeddingStoreFactory(storeConfig); return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
} }
throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider()); throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: " + embeddingStoreConfig.getProvider());
} }
} }

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.server.web.service.DimensionService;
import com.tencent.supersonic.headless.server.web.service.MetricService; import com.tencent.supersonic.headless.server.web.service.MetricService;
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
import dev.langchain4j.store.embedding.TextSegmentConvert; import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@@ -35,15 +36,13 @@ public class MetaEmbeddingTask implements CommandLineRunner {
@Autowired @Autowired
private DimensionService dimensionService; private DimensionService dimensionService;
@Autowired
private EmbeddingStoreFactory embeddingStoreFactory;
@PreDestroy @PreDestroy
public void onShutdown() { public void onShutdown() {
embeddingStorePersistFile(); embeddingStorePersistFile();
} }
private void embeddingStorePersistFile() { private void embeddingStorePersistFile() {
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) { if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) {
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
InMemoryEmbeddingStoreFactory inMemoryFactory = InMemoryEmbeddingStoreFactory inMemoryFactory =