[improvement](chat) Add an in_process provider and support offline loading of local embedding models. (#505)

lexluo09
2023-12-14 14:16:03 +08:00
committed by GitHub
parent 169262cc62
commit 287a6561ff
9 changed files with 292 additions and 55 deletions

View File

@@ -186,6 +186,10 @@
         <artifactId>log4j-api</artifactId>
         <version>${apache.log4j.version}</version>
     </dependency>
+    <dependency>
+        <groupId>dev.langchain4j</groupId>
+        <artifactId>langchain4j-embeddings</artifactId>
+    </dependency>
 </dependencies>
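
The langchain4j-embeddings artifact supplies the in-process ONNX machinery (AbstractInProcessEmbeddingModel, OnnxBertBiEncoder, PoolingMode) that the new S2OnnxEmbeddingModel below builds on; no <version> is declared here, so the version is presumably managed by a parent POM or BOM.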

View File

@@ -0,0 +1,30 @@
package dev.langchain4j;

class InProcess {

    /**
     * The local file-system path of the ONNX model.
     */
    private String modelPath;

    /**
     * The local file-system path of the model's vocabulary file.
     */
    private String vocabularyPath;

    public String getModelPath() {
        return modelPath;
    }

    public void setModelPath(String modelPath) {
        this.modelPath = modelPath;
    }

    public String getVocabularyPath() {
        return vocabularyPath;
    }

    public void setVocabularyPath(String vocabularyPath) {
        this.vocabularyPath = vocabularyPath;
    }
}
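
Through Spring's relaxed binding, these fields map to the s2.langchain4j.embedding-model.in-process.model-path and ...vocabulary-path keys (the sample configs below write them as inProcess.modelPath / inProcess.vocabularyPath, which relaxed binding accepts equally). When modelPath is left blank, the auto-configuration falls back to the bundled all-MiniLM-L6-v2 model, so the provider still works offline out of the box.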

View File

@@ -0,0 +1,58 @@
package dev.langchain4j;

import org.springframework.boot.context.properties.NestedConfigurationProperty;

class S2EmbeddingModel {

    @NestedConfigurationProperty
    private S2ModelProvider provider;

    @NestedConfigurationProperty
    private OpenAi openAi;

    @NestedConfigurationProperty
    private HuggingFace huggingFace;

    @NestedConfigurationProperty
    private LocalAi localAi;

    @NestedConfigurationProperty
    private InProcess inProcess;

    public S2ModelProvider getProvider() {
        return provider;
    }

    public void setProvider(S2ModelProvider provider) {
        this.provider = provider;
    }

    public OpenAi getOpenAi() {
        return openAi;
    }

    public void setOpenAi(OpenAi openAi) {
        this.openAi = openAi;
    }

    public HuggingFace getHuggingFace() {
        return huggingFace;
    }

    public void setHuggingFace(HuggingFace huggingFace) {
        this.huggingFace = huggingFace;
    }

    public LocalAi getLocalAi() {
        return localAi;
    }

    public void setLocalAi(LocalAi localAi) {
        this.localAi = localAi;
    }

    public InProcess getInProcess() {
        return inProcess;
    }

    public void setInProcess(InProcess inProcess) {
        this.inProcess = inProcess;
    }
}
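
Only the nested block matching the configured provider is consulted by the auto-configuration below; the remaining blocks may be left unset.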

View File

@@ -7,6 +7,7 @@ import static dev.langchain4j.internal.Utils.isNullOrBlank;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
 import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
 import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
 import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
 import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
@@ -19,7 +20,6 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiLanguageModel;
 import dev.langchain4j.model.openai.OpenAiModerationModel;
-import java.util.Arrays;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
 import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -29,16 +29,16 @@ import org.springframework.context.annotation.Lazy;
 import org.springframework.context.annotation.Primary;
 @Configuration
-@EnableConfigurationProperties(LangChain4jProperties.class)
+@EnableConfigurationProperties(S2LangChain4jProperties.class)
 public class S2LangChain4jAutoConfiguration {
     @Autowired
-    private LangChain4jProperties properties;
+    private S2LangChain4jProperties properties;
     @Bean
     @Lazy
     @ConditionalOnMissingBean
-    ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
+    ChatLanguageModel chatLanguageModel(S2LangChain4jProperties properties) {
         if (properties.getChatModel() == null) {
             throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
                     + "langchain4j.chat-model.provider = openai\n"
@@ -113,7 +113,7 @@ public class S2LangChain4jAutoConfiguration {
     @Bean
     @Lazy
     @ConditionalOnMissingBean
-    LanguageModel languageModel(LangChain4jProperties properties) {
+    LanguageModel languageModel(S2LangChain4jProperties properties) {
         if (properties.getLanguageModel() == null) {
             throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
                     + "langchain4j.language-model.provider = openai\n"
@@ -187,11 +187,12 @@ public class S2LangChain4jAutoConfiguration {
     @Lazy
     @ConditionalOnMissingBean
     @Primary
-    EmbeddingModel embeddingModel(LangChain4jProperties properties) {
-        if (properties.getEmbeddingModel() == null || !Arrays.stream(ModelProvider.values())
-                .anyMatch(provider -> provider.equals(properties.getEmbeddingModel().getProvider()))) {
-            return new AllMiniLmL6V2EmbeddingModel();
+    EmbeddingModel embeddingModel(S2LangChain4jProperties properties) {
+        if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
+            throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n"
+                    + "langchain4j.embedding-model.provider = openai\n"
+                    + "langchain4j.embedding-model.openai.api-key = sk-...\n");
         }
         switch (properties.getEmbeddingModel().getProvider()) {
@@ -243,15 +244,23 @@ public class S2LangChain4jAutoConfiguration {
                         .logRequests(localAi.getLogRequests())
                         .logResponses(localAi.getLogResponses())
                         .build();
+            case IN_PROCESS:
+                InProcess inProcess = properties.getEmbeddingModel().getInProcess();
+                if (isNullOrBlank(inProcess.getModelPath())) {
+                    return new AllMiniLmL6V2EmbeddingModel();
+                }
+                return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());
             default:
-                return new AllMiniLmL6V2EmbeddingModel();
+                throw illegalConfiguration("Unsupported embedding model provider: %s",
+                        properties.getEmbeddingModel().getProvider());
         }
     }
     @Bean
     @Lazy
     @ConditionalOnMissingBean
-    ModerationModel moderationModel(LangChain4jProperties properties) {
+    ModerationModel moderationModel(S2LangChain4jProperties properties) {
         if (properties.getModerationModel() == null) {
             throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
                     + "langchain4j.moderation-model.provider = openai\n"

View File

@@ -0,0 +1,49 @@
package dev.langchain4j;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

@ConfigurationProperties(prefix = "s2.langchain4j")
public class S2LangChain4jProperties {

    @NestedConfigurationProperty
    private ChatModel chatModel;

    @NestedConfigurationProperty
    private LanguageModel languageModel;

    @NestedConfigurationProperty
    private S2EmbeddingModel embeddingModel;

    @NestedConfigurationProperty
    private ModerationModel moderationModel;

    public ChatModel getChatModel() {
        return chatModel;
    }

    public void setChatModel(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    public LanguageModel getLanguageModel() {
        return languageModel;
    }

    public void setLanguageModel(LanguageModel languageModel) {
        this.languageModel = languageModel;
    }

    public S2EmbeddingModel getEmbeddingModel() {
        return embeddingModel;
    }

    public void setEmbeddingModel(S2EmbeddingModel embeddingModel) {
        this.embeddingModel = embeddingModel;
    }

    public ModerationModel getModerationModel() {
        return moderationModel;
    }

    public void setModerationModel(ModerationModel moderationModel) {
        this.moderationModel = moderationModel;
    }
}
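
Binding under the s2.langchain4j prefix, rather than plain langchain4j, presumably avoids clashing with the upstream langchain4j Spring Boot starter, which binds the langchain4j.* keys; note that the illegalConfiguration messages in the auto-configuration above still quote langchain4j.* property names.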

View File

@@ -0,0 +1,9 @@
package dev.langchain4j;

enum S2ModelProvider {
    OPEN_AI,
    HUGGING_FACE,
    LOCAL_AI,
    IN_PROCESS
}

View File

@@ -0,0 +1,61 @@
package dev.langchain4j.model.embedding;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.apache.commons.lang3.StringUtils;

/**
 * An embedding model that runs within your Java application's process.
 * Any BERT-based model (e.g., from HuggingFace) can be used, as long as it is in ONNX format.
 * Information on how to convert models into ONNX format can be found <a
 * href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
 * Many models already converted to ONNX format are available <a href="https://huggingface.co/Xenova">here</a>.
 * Copied from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
 */
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {

    private final OnnxBertBiEncoder model;

    /**
     * @param pathToModel    The path to the .onnx model file (e.g., "/home/me/model.onnx").
     * @param vocabularyPath The path to the vocabulary file; when blank, the bundled
     *                       "/bert-vocabulary-en.txt" resource is used instead.
     */
    public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
        URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
        if (StringUtils.isNotBlank(vocabularyPath)) {
            try {
                resource = Paths.get(vocabularyPath).toUri().toURL();
            } catch (MalformedURLException e) {
                throw new RuntimeException(e);
            }
        }
        model = loadFromFileSystem(Paths.get(pathToModel), resource);
    }

    /**
     * @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
     */
    public S2OnnxEmbeddingModel(String pathToModel) {
        this(pathToModel, null);
    }

    @Override
    protected OnnxBertBiEncoder model() {
        return model;
    }

    static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
        try {
            return new OnnxBertBiEncoder(
                    Files.newInputStream(pathToModel),
                    vocabularyFile,
                    PoolingMode.MEAN
            );
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
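
A minimal offline usage sketch; the model and vocabulary paths below are the placeholder paths from the sample configuration, not shipped files:

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;

public class OfflineEmbeddingExample {

    public static void main(String[] args) {
        // Load a locally exported ONNX model and its vocabulary; no network access needed.
        S2OnnxEmbeddingModel model = new S2OnnxEmbeddingModel("/data/model.onnx", "/data/onnx_vocab.txt");

        // embed(..) returns Response<Embedding>; content() unwraps the embedding vector.
        Embedding embedding = model.embed("What is a semantic layer?").content();
        System.out.println("dimension = " + embedding.vector().length);
    }
}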

View File

@@ -43,31 +43,37 @@ functionCall:
   url: http://127.0.0.1:9092
 #langchain4j config
-langchain4j:
-  #1.chat-model
-  chat-model:
-    provider: open_ai
-    openai:
-      api-key: api_key
-      model-name: gpt-3.5-turbo
-      temperature: 0.0
-      timeout: PT60S
-  #2.embedding-model
-  #2.1 in_memory(default)
-  #2.2 open_ai
-#  embedding-model:
-#    provider: open_ai
-#    openai:
-#      api-key: api_key
-#      modelName: all-minilm-l6-v2.onnx
-  #2.2 hugging_face
-#  embedding-model:
-#    provider: hugging_face
-#    hugging-face:
-#      access-token: hg_access_token
-#      model-id: sentence-transformers/all-MiniLM-L6-v2
-#      timeout: 1h
+s2:
+  langchain4j:
+    #1.chat-model
+    chat-model:
+      provider: open_ai
+      openai:
+        api-key: api_key
+        model-name: gpt-3.5-turbo
+        temperature: 0.0
+        timeout: PT60S
+    #2.embedding-model
+    #2.1 in_process(default)
+    embedding-model:
+      provider: in_process
+#      inProcess:
+#        modelPath: /data/model.onnx
+#        vocabularyPath: /data/onnx_vocab.txt
+    #2.2 open_ai
+#    embedding-model:
+#      provider: open_ai
+#      openai:
+#        api-key: api_key
+#        modelName: all-minilm-l6-v2.onnx
+    #2.3 hugging_face
+#    embedding-model:
+#      provider: hugging_face
+#      hugging-face:
+#        access-token: hg_access_token
+#        model-id: sentence-transformers/all-MiniLM-L6-v2
+#        timeout: 1h
 #langchain4j log
 logging:
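
With this default (provider: in_process and the inProcess block commented out), the application embeds fully offline using the bundled all-MiniLM-L6-v2 model; uncommenting modelPath (and optionally vocabularyPath) switches to a locally stored ONNX model with no other change.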

View File

@@ -35,27 +35,38 @@ mybatis:
 #langchain4j config
-langchain4j:
-  #1.chat-model
-  chat-model:
-    provider: open_ai
-    openai:
-      api-key: api_key
-      model-name: gpt-3.5-turbo
-      temperature: 0.0
-      timeout: PT60S
-  #2.embedding-model
-#  embedding-model:
-#    hugging-face:
-#      access-token: hg_access_token
-#      model-id: sentence-transformers/all-MiniLM-L6-v2
-#      timeout: 1h
-#  embedding-model:
-#    provider: open_ai
-#    openai:
-#      api-key: api_key
-#      modelName: all-minilm-l6-v2.onnx
+s2:
+  langchain4j:
+    #1.chat-model
+    chat-model:
+      provider: open_ai
+      openai:
+        api-key: api_key
+        model-name: gpt-3.5-turbo
+        temperature: 0.0
+        timeout: PT60S
+    #2.embedding-model
+    #2.1 in_process(default)
+    embedding-model:
+      provider: in_process
+#      inProcess:
+#        modelPath: /data/model.onnx
+#        vocabularyPath: /data/onnx_vocab.txt
+    #2.2 open_ai
+#    embedding-model:
+#      provider: open_ai
+#      openai:
+#        api-key: api_key
+#        modelName: all-minilm-l6-v2.onnx
+    #2.3 hugging_face
+#    embedding-model:
+#      provider: hugging_face
+#      hugging-face:
+#        access-token: hg_access_token
+#        model-id: sentence-transformers/all-MiniLM-L6-v2
+#        timeout: 1h
 #langchain4j log