(improvement)(chat) Support configuring embeddingModel or embeddingStore at the agent level. (#1361)

This commit is contained in:
lexluo09
2024-07-06 20:44:23 +08:00
committed by GitHub
parent d39db734c4
commit 6db6aaf98d
42 changed files with 669 additions and 299 deletions

View File

@@ -0,0 +1,65 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
public class ModelProvider {
private static final Map<String, ModelFactory> factories = new HashMap<>();
public static void add(Provider provider, ModelFactory modelFactory) {
factories.put(provider.name(), modelFactory);
}
public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) {
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
return ContextUtils.getBean(ChatLanguageModel.class);
}
ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase());
if (modelFactory != null) {
return modelFactory.createChatModel(llmConfig);
}
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
}
public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
if (modelConfig == null || modelConfig.getChatModel() == null
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
return ContextUtils.getBean(ChatLanguageModel.class);
}
ChatModelConfig chatModel = modelConfig.getChatModel();
ModelFactory modelFactory = factories.get(chatModel.getProvider().toUpperCase());
if (modelFactory != null) {
return modelFactory.createChatModel(chatModel);
}
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
}
public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) {
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
return ContextUtils.getBean(EmbeddingModel.class);
}
EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel();
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
if (modelFactory != null) {
return modelFactory.createEmbeddingModel(embeddingModel);
}
throw new RuntimeException("Unsupported EmbeddingModel provider: " + embeddingModel.getProvider());
}
}