diff --git a/common/pom.xml b/common/pom.xml
index 169ffd3cb..97d3ae7d4 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -186,6 +186,10 @@
             <artifactId>log4j-api</artifactId>
             <version>${apache.log4j.version}</version>
         </dependency>
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-embeddings</artifactId>
+        </dependency>
     </dependencies>
 
     <build>
diff --git a/launchers/common/src/main/java/dev/langchain4j/InProcess.java b/launchers/common/src/main/java/dev/langchain4j/InProcess.java
new file mode 100644
index 000000000..790eb39a1
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/InProcess.java
@@ -0,0 +1,30 @@
+package dev.langchain4j;
+
+class InProcess {
+
+    /**
+     * Local filesystem path to the ONNX model file.
+     */
+    private String modelPath;
+
+    /**
+     * Local filesystem path to the model's vocabulary file.
+     */
+    private String vocabularyPath;
+
+    public String getModelPath() {
+        return modelPath;
+    }
+
+    public void setModelPath(String modelPath) {
+        this.modelPath = modelPath;
+    }
+
+    public String getVocabularyPath() {
+        return vocabularyPath;
+    }
+
+    public void setVocabularyPath(String vocabularyPath) {
+        this.vocabularyPath = vocabularyPath;
+    }
+}
diff --git a/launchers/common/src/main/java/dev/langchain4j/S2EmbeddingModel.java b/launchers/common/src/main/java/dev/langchain4j/S2EmbeddingModel.java
new file mode 100644
index 000000000..94ffbe0ea
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/S2EmbeddingModel.java
@@ -0,0 +1,58 @@
+package dev.langchain4j;
+
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+class S2EmbeddingModel {
+
+    @NestedConfigurationProperty
+    private S2ModelProvider provider;
+    @NestedConfigurationProperty
+    private OpenAi openAi;
+    @NestedConfigurationProperty
+    private HuggingFace huggingFace;
+    @NestedConfigurationProperty
+    private LocalAi localAi;
+
+    @NestedConfigurationProperty
+    private InProcess inProcess;
+
+    public S2ModelProvider getProvider() {
+        return provider;
+    }
+
+    public void setProvider(S2ModelProvider provider) {
+        this.provider = provider;
+    }
+
+    public OpenAi getOpenAi() {
+        return openAi;
+    }
+
+    public void setOpenAi(OpenAi openAi) {
+        this.openAi = openAi;
+    }
+
+    public HuggingFace getHuggingFace() {
+        return huggingFace;
+    }
+
+    public void setHuggingFace(HuggingFace huggingFace) {
+        this.huggingFace = huggingFace;
+    }
+
+    public LocalAi getLocalAi() {
+        return localAi;
+    }
+
+    public void setLocalAi(LocalAi localAi) {
+        this.localAi = localAi;
+    }
+
+    public InProcess getInProcess() {
+        return inProcess;
+    }
+
+    public void setInProcess(InProcess inProcess) {
+        this.inProcess = inProcess;
+    }
+}
diff --git a/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java
index 27ab0b846..10ae8aa50 100644
--- a/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java
+++ b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java
@@ -7,6 +7,7 @@ import static dev.langchain4j.internal.Utils.isNullOrBlank;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
 import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
 import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
 import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
 import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
@@ -19,7 +20,6 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiLanguageModel;
 import dev.langchain4j.model.openai.OpenAiModerationModel;
-import java.util.Arrays;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
 import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -29,16 +29,16 @@ import org.springframework.context.annotation.Lazy;
 import org.springframework.context.annotation.Primary;
 
 @Configuration
-@EnableConfigurationProperties(LangChain4jProperties.class)
+@EnableConfigurationProperties(S2LangChain4jProperties.class)
 public class S2LangChain4jAutoConfiguration {
 
     @Autowired
-    private LangChain4jProperties properties;
+    private S2LangChain4jProperties properties;
 
     @Bean
     @Lazy
     @ConditionalOnMissingBean
-    ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
+    ChatLanguageModel chatLanguageModel(S2LangChain4jProperties properties) {
         if (properties.getChatModel() == null) {
             throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
                     + "langchain4j.chat-model.provider = openai\n"
@@ -113,7 +113,7 @@ public class S2LangChain4jAutoConfiguration {
     @Bean
     @Lazy
     @ConditionalOnMissingBean
-    LanguageModel languageModel(LangChain4jProperties properties) {
+    LanguageModel languageModel(S2LangChain4jProperties properties) {
         if (properties.getLanguageModel() == null) {
             throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
                     + "langchain4j.language-model.provider = openai\n"
@@ -187,11 +187,12 @@ public class S2LangChain4jAutoConfiguration {
     @Lazy
     @ConditionalOnMissingBean
     @Primary
-    EmbeddingModel embeddingModel(LangChain4jProperties properties) {
+    EmbeddingModel embeddingModel(S2LangChain4jProperties properties) {
 
-        if (properties.getEmbeddingModel() == null || !Arrays.stream(ModelProvider.values())
-                .anyMatch(provider -> provider.equals(properties.getEmbeddingModel().getProvider()))) {
-            return new AllMiniLmL6V2EmbeddingModel();
+        if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
+            throw illegalConfiguration("\n\nPlease define 's2.langchain4j.embedding-model' properties, for example:\n"
+                    + "s2.langchain4j.embedding-model.provider = openai\n"
+                    + "s2.langchain4j.embedding-model.openai.api-key = sk-...\n");
         }
 
         switch (properties.getEmbeddingModel().getProvider()) {
@@ -243,15 +244,23 @@ public class S2LangChain4jAutoConfiguration {
                     .logRequests(localAi.getLogRequests())
                     .logResponses(localAi.getLogResponses())
                     .build();
+            case IN_PROCESS:
+                InProcess inProcess = properties.getEmbeddingModel().getInProcess();
+                if (inProcess == null || isNullOrBlank(inProcess.getModelPath())) {
+                    return new AllMiniLmL6V2EmbeddingModel();
+                }
+                return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());
+
             default:
-                return new AllMiniLmL6V2EmbeddingModel();
+                throw illegalConfiguration("Unsupported embedding model provider: %s",
+                        properties.getEmbeddingModel().getProvider());
         }
     }
 
     @Bean
     @Lazy
     @ConditionalOnMissingBean
-    ModerationModel moderationModel(LangChain4jProperties properties) {
+    ModerationModel moderationModel(S2LangChain4jProperties properties) {
         if (properties.getModerationModel() == null) {
             throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
                     + "langchain4j.moderation-model.provider = openai\n"
diff --git a/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jProperties.java b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jProperties.java
new file mode 100644
index 000000000..b6186899d
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jProperties.java
@@ -0,0 +1,49 @@
+package dev.langchain4j;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+@ConfigurationProperties(prefix = "s2.langchain4j")
+public class S2LangChain4jProperties {
+
+    @NestedConfigurationProperty
+    private ChatModel chatModel;
+    @NestedConfigurationProperty
+    private LanguageModel languageModel;
+    @NestedConfigurationProperty
+    private S2EmbeddingModel embeddingModel;
+    @NestedConfigurationProperty
+    private ModerationModel moderationModel;
+
+    public ChatModel getChatModel() {
+        return chatModel;
+    }
+
+    public void setChatModel(ChatModel chatModel) {
+        this.chatModel = chatModel;
+    }
+
+    public LanguageModel getLanguageModel() {
+        return languageModel;
+    }
+
+    public void setLanguageModel(LanguageModel languageModel) {
+        this.languageModel = languageModel;
+    }
+
+    public S2EmbeddingModel getEmbeddingModel() {
+        return embeddingModel;
+    }
+
+    public void setEmbeddingModel(S2EmbeddingModel embeddingModel) {
+        this.embeddingModel = embeddingModel;
+    }
+
+    public ModerationModel getModerationModel() {
+        return moderationModel;
+    }
+
+    public void setModerationModel(ModerationModel moderationModel) {
+        this.moderationModel = moderationModel;
+    }
+}
diff --git a/launchers/common/src/main/java/dev/langchain4j/S2ModelProvider.java b/launchers/common/src/main/java/dev/langchain4j/S2ModelProvider.java
new file mode 100644
index 000000000..606e6a237
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/S2ModelProvider.java
@@ -0,0 +1,9 @@
+package dev.langchain4j;
+
+enum S2ModelProvider {
+
+    OPEN_AI,
+    HUGGING_FACE,
+    LOCAL_AI,
+    IN_PROCESS
+}
diff --git a/launchers/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java b/launchers/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java
new file mode 100644
index 000000000..73bdc61dc
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java
@@ -0,0 +1,64 @@
+package dev.langchain4j.model.embedding;
+
+import java.io.IOException;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import org.apache.commons.lang3.StringUtils;
+
+/**
+ * An embedding model that runs within your Java application's process.
+ * Any BERT-based model (e.g., from HuggingFace) can be used, as long as it is in ONNX format.
+ * See the ONNX Runtime documentation for how to convert models into ONNX format.
+ * Pre-converted models are published in the langchain4j-embeddings-* modules.
+ * Copied from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
+ */
+public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
+
+    private final OnnxBertBiEncoder model;
+
+    /**
+     * @param pathToModel    The path to the .onnx model file (e.g., "/home/me/model.onnx").
+     * @param vocabularyPath The path to the vocabulary file; may be null or blank, in which
+     *                       case the bundled bert-vocabulary-en.txt is used.
+     */
+    public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
+        URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
+        if (StringUtils.isNotBlank(vocabularyPath)) {
+            try {
+                resource = Paths.get(vocabularyPath).toUri().toURL();
+            } catch (MalformedURLException e) {
+                throw new RuntimeException(e);
+            }
+        }
+        model = loadFromFileSystem(Paths.get(pathToModel), resource);
+    }
+
+    /**
+     * Creates a model that uses the bundled bert-vocabulary-en.txt vocabulary.
+     *
+     * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
+     */
+    public S2OnnxEmbeddingModel(String pathToModel) {
+        this(pathToModel, null);
+    }
+
+    @Override
+    protected OnnxBertBiEncoder model() {
+        return model;
+    }
+
+    static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
+        try {
+            return new OnnxBertBiEncoder(
+                    Files.newInputStream(pathToModel),
+                    vocabularyFile,
+                    PoolingMode.MEAN
+            );
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml
index 065895a83..24cc868b9 100644
--- a/launchers/standalone/src/main/resources/application-local.yaml
+++ b/launchers/standalone/src/main/resources/application-local.yaml
@@ -43,31 +43,37 @@ functionCall:
     url: http://127.0.0.1:9092
 
 #langchain4j config
-langchain4j:
-  #1.chat-model
-  chat-model:
-    provider: open_ai
-    openai:
-      api-key: api_key
-      model-name: gpt-3.5-turbo
-      temperature: 0.0
-      timeout: PT60S
-  #2.embedding-model
-  #2.1 in_memory(default)
-  #2.2 open_ai
-#  embedding-model:
-#    provider: open_ai
-#    openai:
-#      api-key: api_key
-#      modelName: all-minilm-l6-v2.onnx
+s2:
+  langchain4j:
+    #1.chat-model
+    chat-model:
+      provider: open_ai
+      openai:
+        api-key: api_key
+        model-name: gpt-3.5-turbo
+        temperature: 0.0
+        timeout: PT60S
+    #2.embedding-model
+    #2.1 in_process (falls back to the in-memory all-MiniLM model when no modelPath is set)
+    embedding-model:
+      provider: in_process
+#      inProcess:
+#        modelPath: /data/model.onnx
+#        vocabularyPath: /data/onnx_vocab.txt
+    #2.2 open_ai
+    #  embedding-model:
+    #    provider: open_ai
+    #    openai:
+    #      api-key: api_key
+    #      modelName: all-minilm-l6-v2.onnx
 
-  #2.2 hugging_face
-#  embedding-model:
-#    provider: hugging_face
-#    hugging-face:
-#      access-token: hg_access_token
-#      model-id: sentence-transformers/all-MiniLM-L6-v2
-#      timeout: 1h
+    #2.3 hugging_face
+    #  embedding-model:
+    #    provider: hugging_face
+    #    hugging-face:
+    #      access-token: hg_access_token
+    #      model-id: sentence-transformers/all-MiniLM-L6-v2
+    #      timeout: 1h
 
 #langchain4j log
 logging:
diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml
index bd3c0ed9b..2ad66636e 100644
--- a/launchers/standalone/src/test/resources/application-local.yaml
+++ b/launchers/standalone/src/test/resources/application-local.yaml
@@ -35,27 +35,38 @@ mybatis:
 
 
 #langchain4j config
-langchain4j:
-  #1.chat-model
-  chat-model:
-    provider: open_ai
-    openai:
-      api-key: api_key
-      model-name: gpt-3.5-turbo
-      temperature: 0.0
-      timeout: PT60S
-  #2.embedding-model
-#  embedding-model:
-#    hugging-face:
-#      access-token: hg_access_token
-#      model-id: sentence-transformers/all-MiniLM-L6-v2
-#      timeout: 1h
+s2:
+  langchain4j:
+    #1.chat-model
+    chat-model:
+      provider: open_ai
+      openai:
+        api-key: api_key
+        model-name: gpt-3.5-turbo
+        temperature: 0.0
+        timeout: PT60S
+    #2.embedding-model
+    #2.1 in_process (falls back to the in-memory all-MiniLM model when no modelPath is set)
+    embedding-model:
+      provider: in_process
+#      inProcess:
+#        modelPath: /data/model.onnx
+#        vocabularyPath: /data/onnx_vocab.txt
+    #2.2 open_ai
+    #  embedding-model:
+    #    provider: open_ai
+    #    openai:
+    #      api-key: api_key
+    #      modelName: all-minilm-l6-v2.onnx
+
+    #2.3 hugging_face
+    #  embedding-model:
+    #    provider: hugging_face
+    #    hugging-face:
+    #      access-token: hg_access_token
+    #      model-id: sentence-transformers/all-MiniLM-L6-v2
+    #      timeout: 1h
 
-#  embedding-model:
-#    provider: open_ai
-#    openai:
-#      api-key: api_key
-#      modelName: all-minilm-l6-v2.onnx
 
 
 #langchain4j log