[improvement](chat) Add an in_process provider and support offline loading of local embedding models. (#505)

This commit is contained in:
lexluo09
2023-12-14 14:16:03 +08:00
committed by GitHub
parent 169262cc62
commit 287a6561ff
9 changed files with 292 additions and 55 deletions

View File

@@ -7,6 +7,7 @@ import static dev.langchain4j.internal.Utils.isNullOrBlank;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
@@ -19,7 +20,6 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiLanguageModel;
import dev.langchain4j.model.openai.OpenAiModerationModel;
import java.util.Arrays;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -29,16 +29,16 @@ import org.springframework.context.annotation.Lazy;
import org.springframework.context.annotation.Primary;
@Configuration
@EnableConfigurationProperties(LangChain4jProperties.class)
@EnableConfigurationProperties(S2LangChain4jProperties.class)
public class S2LangChain4jAutoConfiguration {
@Autowired
private LangChain4jProperties properties;
private S2LangChain4jProperties properties;
@Bean
@Lazy
@ConditionalOnMissingBean
ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
ChatLanguageModel chatLanguageModel(S2LangChain4jProperties properties) {
if (properties.getChatModel() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
+ "langchain4j.chat-model.provider = openai\n"
@@ -113,7 +113,7 @@ public class S2LangChain4jAutoConfiguration {
@Bean
@Lazy
@ConditionalOnMissingBean
LanguageModel languageModel(LangChain4jProperties properties) {
LanguageModel languageModel(S2LangChain4jProperties properties) {
if (properties.getLanguageModel() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
+ "langchain4j.language-model.provider = openai\n"
@@ -187,11 +187,12 @@ public class S2LangChain4jAutoConfiguration {
@Lazy
@ConditionalOnMissingBean
@Primary
EmbeddingModel embeddingModel(LangChain4jProperties properties) {
EmbeddingModel embeddingModel(S2LangChain4jProperties properties) {
if (properties.getEmbeddingModel() == null || !Arrays.stream(ModelProvider.values())
.anyMatch(provider -> provider.equals(properties.getEmbeddingModel().getProvider()))) {
return new AllMiniLmL6V2EmbeddingModel();
if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n"
+ "langchain4j.embedding-model.provider = openai\n"
+ "langchain4j.embedding-model.openai.api-key = sk-...\n");
}
switch (properties.getEmbeddingModel().getProvider()) {
@@ -243,15 +244,23 @@ public class S2LangChain4jAutoConfiguration {
.logRequests(localAi.getLogRequests())
.logResponses(localAi.getLogResponses())
.build();
case IN_PROCESS:
InProcess inProcess = properties.getEmbeddingModel().getInProcess();
if (isNullOrBlank(inProcess.getModelPath())) {
return new AllMiniLmL6V2EmbeddingModel();
}
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());
default:
return new AllMiniLmL6V2EmbeddingModel();
throw illegalConfiguration("Unsupported embedding model provider: %s",
properties.getEmbeddingModel().getProvider());
}
}
@Bean
@Lazy
@ConditionalOnMissingBean
ModerationModel moderationModel(LangChain4jProperties properties) {
ModerationModel moderationModel(S2LangChain4jProperties properties) {
if (properties.getModerationModel() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
+ "langchain4j.moderation-model.provider = openai\n"