(improvement)(chat) Remove initial default values, make the configuration file settings take effect, and optimize the code. (#1406)

This commit is contained in:
lexluo09
2024-07-14 23:20:06 +08:00
committed by GitHub
parent d5c78d87e7
commit 529251097b
8 changed files with 25 additions and 33 deletions

View File

@@ -20,7 +20,7 @@ import java.util.List;
public class EmbeddingModelParameterConfig extends ParameterConfig { public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER = public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, new Parameter("s2.embedding.model.provider", "",
"接口协议", "", "接口协议", "",
"string", "向量模型配置", "string", "向量模型配置",
Lists.newArrayList(InMemoryModelFactory.PROVIDER, Lists.newArrayList(InMemoryModelFactory.PROVIDER,

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.common.config;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.InMemoryModelFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -13,7 +12,7 @@ import java.util.List;
@Slf4j @Slf4j
public class EmbeddingStoreParameterConfig extends ParameterConfig { public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER = public static final Parameter EMBEDDING_STORE_PROVIDER =
new Parameter("s2.embedding.store.provider", InMemoryModelFactory.PROVIDER, new Parameter("s2.embedding.store.provider", "",
"向量库类型", "", "向量库类型", "",
"string", "向量库配置"); "string", "向量库配置");

View File

@@ -2,8 +2,6 @@ package com.tencent.supersonic.common.service.impl;
import com.google.common.cache.Cache; import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
@@ -26,7 +24,6 @@ import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
@@ -41,8 +38,6 @@ import java.util.stream.Collectors;
@Service @Service
@Slf4j @Slf4j
public class EmbeddingServiceImpl implements EmbeddingService { public class EmbeddingServiceImpl implements EmbeddingService {
@Autowired
private EmbeddingModelParameterConfig embeddingModelParameterConfig;
private Cache<String, Boolean> cache = CacheBuilder.newBuilder() private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
.maximumSize(10000) .maximumSize(10000)
@@ -57,7 +52,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
for (TextSegment query : queries) { for (TextSegment query : queries) {
String question = query.text(); String question = query.text();
try { try {
EmbeddingModel embeddingModel = getEmbeddingModel(); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
Embedding embedding = embeddingModel.embed(question).content(); Embedding embedding = embeddingModel.embed(question).content();
boolean existSegment = existSegment(embeddingStore, query, embedding); boolean existSegment = existSegment(embeddingStore, query, embedding);
if (existSegment) { if (existSegment) {
@@ -126,7 +121,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
List<String> queryTextsList = retrieveQuery.getQueryTextsList(); List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition(); Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) { for (String queryText : queryTextsList) {
EmbeddingModel embeddingModel = getEmbeddingModel(); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
Embedding embeddedText = embeddingModel.embed(queryText).content(); Embedding embeddedText = embeddingModel.embed(queryText).content();
Filter filter = createCombinedFilter(filterCondition); Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
@@ -174,9 +169,4 @@ public class EmbeddingServiceImpl implements EmbeddingService {
} }
return result; return result;
} }
private EmbeddingModel getEmbeddingModel() {
EmbeddingModelConfig embeddingModelConfig = embeddingModelParameterConfig.convert();
return ModelProvider.getEmbeddingModel(embeddingModelConfig);
}
} }

View File

@@ -24,7 +24,7 @@ public class InMemoryAutoConfig {
@Bean @Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.persist-path") @ConditionalOnProperty(PREFIX + ".embedding-store.persist-path")
EmbeddingStoreFactory inMemoryChatModel(Properties properties) { EmbeddingStoreFactory inMemoryChatModel(Properties properties) {
return new InMemoryEmbeddingStoreFactory(properties); return new InMemoryEmbeddingStoreFactory(properties.getEmbeddingStore());
} }
@Bean @Bean

View File

@@ -23,22 +23,20 @@ import java.util.concurrent.CopyOnWriteArraySet;
public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory."; public static final String PERSISTENT_FILE_PRE = "InMemory.";
private Properties properties; private EmbeddingStoreProperties embeddingStore;
public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig)); this(createPropertiesFromConfig(storeConfig));
} }
public InMemoryEmbeddingStoreFactory(Properties properties) { public InMemoryEmbeddingStoreFactory(EmbeddingStoreProperties embeddingStore) {
this.properties = properties; this.embeddingStore = embeddingStore;
} }
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore); BeanUtils.copyProperties(storeConfig, embeddingStore);
properties.setEmbeddingStore(embeddingStore); return embeddingStore;
return properties;
} }
@Override @Override
@@ -97,7 +95,7 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private Path getPersistPath(String collectionName) { private Path getPersistPath(String collectionName) {
String persistFile = PERSISTENT_FILE_PRE + collectionName; String persistFile = PERSISTENT_FILE_PRE + collectionName;
String persistPath = properties.getEmbeddingStore().getPersistPath(); String persistPath = embeddingStore.getPersistPath();
if (StringUtils.isEmpty(persistPath)) { if (StringUtils.isEmpty(persistPath)) {
return null; return null;
} }

View File

@@ -15,6 +15,6 @@ public class MilvusAutoConfig {
@Bean @Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.uri") @ConditionalOnProperty(PREFIX + ".embedding-store.uri")
EmbeddingStoreFactory milvusChatModel(Properties properties) { EmbeddingStoreFactory milvusChatModel(Properties properties) {
return new MilvusEmbeddingStoreFactory(properties); return new MilvusEmbeddingStoreFactory(properties.getEmbeddingStore());
} }
} }

View File

@@ -8,28 +8,25 @@ import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private final Properties properties; private final EmbeddingStoreProperties storeProperties;
public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig)); this(createPropertiesFromConfig(storeConfig));
} }
public MilvusEmbeddingStoreFactory(Properties properties) { public MilvusEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) {
this.properties = properties; this.storeProperties = storeProperties;
} }
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore); BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setUri(storeConfig.getBaseUrl()); embeddingStore.setUri(storeConfig.getBaseUrl());
properties.setEmbeddingStore(embeddingStore); return embeddingStore;
return properties;
} }
@Override @Override
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) { public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
return MilvusEmbeddingStore.builder() return MilvusEmbeddingStore.builder()
.host(storeProperties.getHost()) .host(storeProperties.getHost())
.port(storeProperties.getPort()) .port(storeProperties.getPort())

View File

@@ -1,5 +1,6 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -31,6 +32,13 @@ public class ModelProvider {
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider()); throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
} }
public static EmbeddingModel getEmbeddingModel() {
EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean(
EmbeddingModelParameterConfig.class);
EmbeddingModelConfig embeddingModelConfig = parameterConfig.convert();
return getEmbeddingModel(embeddingModelConfig);
}
public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) { public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) { if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
return ContextUtils.getBean(EmbeddingModel.class); return ContextUtils.getBean(EmbeddingModel.class);