mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:28:12 +00:00
(improvement)(chat) Support integrating embeddingStore into system settings. (#1405)
This commit is contained in:
@@ -15,7 +15,7 @@ import org.springframework.stereotype.Service;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Service("EmbeddingModelConfig")
|
@Service("EmbeddingModelParameterConfig")
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.pojo;
|
|||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
@@ -11,6 +12,7 @@ import java.io.Serializable;
|
|||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode
|
||||||
public class EmbeddingStoreConfig implements Serializable {
|
public class EmbeddingStoreConfig implements Serializable {
|
||||||
private static final long serialVersionUID = 1L;
|
private static final long serialVersionUID = 1L;
|
||||||
private String provider;
|
private String provider;
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
|||||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
@@ -40,10 +41,6 @@ import java.util.stream.Collectors;
|
|||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingModelParameterConfig embeddingModelParameterConfig;
|
private EmbeddingModelParameterConfig embeddingModelParameterConfig;
|
||||||
|
|
||||||
@@ -54,6 +51,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||||
|
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||||
|
|
||||||
for (TextSegment query : queries) {
|
for (TextSegment query : queries) {
|
||||||
@@ -102,6 +100,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
@Override
|
@Override
|
||||||
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
||||||
//Not supported yet in Milvus and Chroma
|
//Not supported yet in Milvus and Chroma
|
||||||
|
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||||
try {
|
try {
|
||||||
if (embeddingStore instanceof InMemoryEmbeddingStore) {
|
if (embeddingStore instanceof InMemoryEmbeddingStore) {
|
||||||
@@ -122,7 +121,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
@Override
|
@Override
|
||||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||||
|
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||||
|
|||||||
@@ -16,6 +16,6 @@ public class ChromaAutoConfig {
|
|||||||
@Bean
|
@Bean
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
|
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
|
||||||
EmbeddingStoreFactory chromaChatModel(Properties properties) {
|
EmbeddingStoreFactory chromaChatModel(Properties properties) {
|
||||||
return new ChromaEmbeddingStoreFactory(properties);
|
return new ChromaEmbeddingStoreFactory(properties.getEmbeddingStore());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -12,19 +12,18 @@ import java.time.Duration;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||||
|
|
||||||
private Properties properties;
|
private EmbeddingStoreProperties storeProperties;
|
||||||
|
|
||||||
public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||||
this(createPropertiesFromConfig(storeConfig));
|
this(createPropertiesFromConfig(storeConfig));
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChromaEmbeddingStoreFactory(Properties properties) {
|
public ChromaEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) {
|
||||||
this.properties = properties;
|
this.storeProperties = storeProperties;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingStore createEmbeddingStore(String collectionName) {
|
public EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||||
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
|
||||||
return ChromaEmbeddingStore.builder()
|
return ChromaEmbeddingStore.builder()
|
||||||
.baseUrl(storeProperties.getBaseUrl())
|
.baseUrl(storeProperties.getBaseUrl())
|
||||||
.collectionName(collectionName)
|
.collectionName(collectionName)
|
||||||
@@ -32,12 +31,10 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
|||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||||
Properties properties = new Properties();
|
|
||||||
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||||
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||||
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
|
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
|
||||||
properties.setEmbeddingStore(embeddingStore);
|
return embeddingStore;
|
||||||
return properties;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2,11 +2,13 @@ package dev.langchain4j.chroma.spring;
|
|||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
|
@ToString
|
||||||
public class EmbeddingStoreProperties {
|
public class EmbeddingStoreProperties {
|
||||||
|
|
||||||
private String baseUrl;
|
private String baseUrl;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import java.util.Map;
|
|||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||||
protected static final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
|
protected final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
public EmbeddingStore<TextSegment> create(String collectionName) {
|
public EmbeddingStore<TextSegment> create(String collectionName) {
|
||||||
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);
|
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package dev.langchain4j.store.embedding;
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingStoreParameterConfig;
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
||||||
@@ -7,20 +8,33 @@ import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
|||||||
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
|
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
public class EmbeddingStoreFactoryProvider {
|
public class EmbeddingStoreFactoryProvider {
|
||||||
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) {
|
protected static final Map<EmbeddingStoreConfig, EmbeddingStoreFactory> factoryMap = new ConcurrentHashMap<>();
|
||||||
if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) {
|
|
||||||
|
public static EmbeddingStoreFactory getFactory() {
|
||||||
|
EmbeddingStoreParameterConfig parameterConfig = ContextUtils.getBean(EmbeddingStoreParameterConfig.class);
|
||||||
|
return getFactory(parameterConfig.convert());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig embeddingStoreConfig) {
|
||||||
|
if (embeddingStoreConfig == null || StringUtils.isBlank(embeddingStoreConfig.getProvider())) {
|
||||||
return ContextUtils.getBean(EmbeddingStoreFactory.class);
|
return ContextUtils.getBean(EmbeddingStoreFactory.class);
|
||||||
}
|
}
|
||||||
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||||
return new ChromaEmbeddingStoreFactory(storeConfig);
|
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||||
|
storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig));
|
||||||
}
|
}
|
||||||
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||||
return new MilvusEmbeddingStoreFactory(storeConfig);
|
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||||
|
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
|
||||||
}
|
}
|
||||||
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||||
return new InMemoryEmbeddingStoreFactory(storeConfig);
|
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||||
|
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
|
||||||
}
|
}
|
||||||
throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider());
|
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: " + embeddingStoreConfig.getProvider());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.server.web.service.DimensionService;
|
|||||||
import com.tencent.supersonic.headless.server.web.service.MetricService;
|
import com.tencent.supersonic.headless.server.web.service.MetricService;
|
||||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
|
||||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
@@ -35,15 +36,13 @@ public class MetaEmbeddingTask implements CommandLineRunner {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private DimensionService dimensionService;
|
private DimensionService dimensionService;
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
|
||||||
|
|
||||||
@PreDestroy
|
@PreDestroy
|
||||||
public void onShutdown() {
|
public void onShutdown() {
|
||||||
embeddingStorePersistFile();
|
embeddingStorePersistFile();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void embeddingStorePersistFile() {
|
private void embeddingStorePersistFile() {
|
||||||
|
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||||
if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) {
|
if (embeddingStoreFactory instanceof InMemoryEmbeddingStoreFactory) {
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
InMemoryEmbeddingStoreFactory inMemoryFactory =
|
InMemoryEmbeddingStoreFactory inMemoryFactory =
|
||||||
|
|||||||
Reference in New Issue
Block a user