[improvement](chat) default use AllMiniLmL6V2EmbeddingModel (#490)

This commit is contained in:
lexluo09
2023-12-11 11:01:47 +08:00
committed by GitHub
parent 899047dbd1
commit 97b11ec244

View File

@@ -19,6 +19,7 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel; import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiLanguageModel; import dev.langchain4j.model.openai.OpenAiLanguageModel;
import dev.langchain4j.model.openai.OpenAiModerationModel; import dev.langchain4j.model.openai.OpenAiModerationModel;
import java.util.Arrays;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -187,10 +188,10 @@ public class S2LangChain4jAutoConfiguration {
@ConditionalOnMissingBean @ConditionalOnMissingBean
@Primary @Primary
EmbeddingModel embeddingModel(LangChain4jProperties properties) { EmbeddingModel embeddingModel(LangChain4jProperties properties) {
if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n" if (properties.getEmbeddingModel() == null || !Arrays.stream(ModelProvider.values())
+ "langchain4j.embedding-model.provider = openai\n" .anyMatch(provider -> provider.equals(properties.getEmbeddingModel().getProvider()))) {
+ "langchain4j.embedding-model.openai.api-key = sk-...\n"); return new AllMiniLmL6V2EmbeddingModel();
} }
switch (properties.getEmbeddingModel().getProvider()) { switch (properties.getEmbeddingModel().getProvider()) {
@@ -242,13 +243,8 @@ public class S2LangChain4jAutoConfiguration {
.logRequests(localAi.getLogRequests()) .logRequests(localAi.getLogRequests())
.logResponses(localAi.getLogResponses()) .logResponses(localAi.getLogResponses())
.build(); .build();
case IN_MEMORY:
return new AllMiniLmL6V2EmbeddingModel();
default: default:
throw illegalConfiguration("Unsupported embedding model provider: %s", return new AllMiniLmL6V2EmbeddingModel();
properties.getEmbeddingModel().getProvider());
} }
} }