mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(chat) Remove initial default values, make the configuration file settings take effect, and optimize the code. (#1406)
This commit is contained in:
@@ -20,7 +20,7 @@ import java.util.List;
|
|||||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||||
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
new Parameter("s2.embedding.model.provider", "",
|
||||||
"接口协议", "",
|
"接口协议", "",
|
||||||
"string", "向量模型配置",
|
"string", "向量模型配置",
|
||||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER,
|
Lists.newArrayList(InMemoryModelFactory.PROVIDER,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.common.config;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@@ -13,7 +12,7 @@ import java.util.List;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||||
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
||||||
new Parameter("s2.embedding.store.provider", InMemoryModelFactory.PROVIDER,
|
new Parameter("s2.embedding.store.provider", "",
|
||||||
"向量库类型", "",
|
"向量库类型", "",
|
||||||
"string", "向量库配置");
|
"string", "向量库配置");
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package com.tencent.supersonic.common.service.impl;
|
|||||||
|
|
||||||
import com.google.common.cache.Cache;
|
import com.google.common.cache.Cache;
|
||||||
import com.google.common.cache.CacheBuilder;
|
import com.google.common.cache.CacheBuilder;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
@@ -26,7 +24,6 @@ import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -41,8 +38,6 @@ import java.util.stream.Collectors;
|
|||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||||
@Autowired
|
|
||||||
private EmbeddingModelParameterConfig embeddingModelParameterConfig;
|
|
||||||
|
|
||||||
private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
|
private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
|
||||||
.maximumSize(10000)
|
.maximumSize(10000)
|
||||||
@@ -57,7 +52,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
for (TextSegment query : queries) {
|
for (TextSegment query : queries) {
|
||||||
String question = query.text();
|
String question = query.text();
|
||||||
try {
|
try {
|
||||||
EmbeddingModel embeddingModel = getEmbeddingModel();
|
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
||||||
Embedding embedding = embeddingModel.embed(question).content();
|
Embedding embedding = embeddingModel.embed(question).content();
|
||||||
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
||||||
if (existSegment) {
|
if (existSegment) {
|
||||||
@@ -126,7 +121,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||||
for (String queryText : queryTextsList) {
|
for (String queryText : queryTextsList) {
|
||||||
EmbeddingModel embeddingModel = getEmbeddingModel();
|
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
||||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||||
Filter filter = createCombinedFilter(filterCondition);
|
Filter filter = createCombinedFilter(filterCondition);
|
||||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
@@ -174,9 +169,4 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private EmbeddingModel getEmbeddingModel() {
|
|
||||||
EmbeddingModelConfig embeddingModelConfig = embeddingModelParameterConfig.convert();
|
|
||||||
return ModelProvider.getEmbeddingModel(embeddingModelConfig);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ public class InMemoryAutoConfig {
|
|||||||
@Bean
|
@Bean
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
|
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
|
||||||
EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
|
EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
|
||||||
return new InMemoryEmbeddingStoreFactory(properties);
|
return new InMemoryEmbeddingStoreFactory(properties.getEmbeddingStore());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
|
|||||||
@@ -23,22 +23,20 @@ import java.util.concurrent.CopyOnWriteArraySet;
|
|||||||
public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||||
|
|
||||||
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
||||||
private Properties properties;
|
private EmbeddingStoreProperties embeddingStore;
|
||||||
|
|
||||||
public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||||
this(createPropertiesFromConfig(storeConfig));
|
this(createPropertiesFromConfig(storeConfig));
|
||||||
}
|
}
|
||||||
|
|
||||||
public InMemoryEmbeddingStoreFactory(Properties properties) {
|
public InMemoryEmbeddingStoreFactory(EmbeddingStoreProperties embeddingStore) {
|
||||||
this.properties = properties;
|
this.embeddingStore = embeddingStore;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||||
Properties properties = new Properties();
|
|
||||||
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||||
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||||
properties.setEmbeddingStore(embeddingStore);
|
return embeddingStore;
|
||||||
return properties;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -97,7 +95,7 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
|||||||
|
|
||||||
private Path getPersistPath(String collectionName) {
|
private Path getPersistPath(String collectionName) {
|
||||||
String persistFile = PERSISTENT_FILE_PRE + collectionName;
|
String persistFile = PERSISTENT_FILE_PRE + collectionName;
|
||||||
String persistPath = properties.getEmbeddingStore().getPersistPath();
|
String persistPath = embeddingStore.getPersistPath();
|
||||||
if (StringUtils.isEmpty(persistPath)) {
|
if (StringUtils.isEmpty(persistPath)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,6 @@ public class MilvusAutoConfig {
|
|||||||
@Bean
|
@Bean
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-store.uri")
|
@ConditionalOnProperty(PREFIX + ".embedding-store.uri")
|
||||||
EmbeddingStoreFactory milvusChatModel(Properties properties) {
|
EmbeddingStoreFactory milvusChatModel(Properties properties) {
|
||||||
return new MilvusEmbeddingStoreFactory(properties);
|
return new MilvusEmbeddingStoreFactory(properties.getEmbeddingStore());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -8,28 +8,25 @@ import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
|||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||||
private final Properties properties;
|
private final EmbeddingStoreProperties storeProperties;
|
||||||
|
|
||||||
public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||||
this(createPropertiesFromConfig(storeConfig));
|
this(createPropertiesFromConfig(storeConfig));
|
||||||
}
|
}
|
||||||
|
|
||||||
public MilvusEmbeddingStoreFactory(Properties properties) {
|
public MilvusEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) {
|
||||||
this.properties = properties;
|
this.storeProperties = storeProperties;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||||
Properties properties = new Properties();
|
|
||||||
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||||
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||||
embeddingStore.setUri(storeConfig.getBaseUrl());
|
embeddingStore.setUri(storeConfig.getBaseUrl());
|
||||||
properties.setEmbeddingStore(embeddingStore);
|
return embeddingStore;
|
||||||
return properties;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||||
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
|
||||||
return MilvusEmbeddingStore.builder()
|
return MilvusEmbeddingStore.builder()
|
||||||
.host(storeProperties.getHost())
|
.host(storeProperties.getHost())
|
||||||
.port(storeProperties.getPort())
|
.port(storeProperties.getPort())
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -31,6 +32,13 @@ public class ModelProvider {
|
|||||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static EmbeddingModel getEmbeddingModel() {
|
||||||
|
EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean(
|
||||||
|
EmbeddingModelParameterConfig.class);
|
||||||
|
EmbeddingModelConfig embeddingModelConfig = parameterConfig.convert();
|
||||||
|
return getEmbeddingModel(embeddingModelConfig);
|
||||||
|
}
|
||||||
|
|
||||||
public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||||
if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
|
if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
|
||||||
return ContextUtils.getBean(EmbeddingModel.class);
|
return ContextUtils.getBean(EmbeddingModel.class);
|
||||||
|
|||||||
Reference in New Issue
Block a user