mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 22:46:49 +00:00
[improvement](chat) Add an in_process provider and support offline loading of local embedding models. (#505)
This commit is contained in:
@@ -186,6 +186,10 @@
|
||||
<artifactId>log4j-api</artifactId>
|
||||
<version>${apache.log4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package dev.langchain4j;
|
||||
|
||||
/**
 * Configuration holder for an in-process (local, offline) embedding model.
 * Carries the filesystem locations of the ONNX model and its vocabulary.
 */
class InProcess {

    /** Local filesystem path to the ONNX model file. */
    private String modelPath;

    /** Local filesystem path to the model's vocabulary file. */
    private String vocabularyPath;

    public String getModelPath() {
        return modelPath;
    }

    public void setModelPath(String path) {
        this.modelPath = path;
    }

    public String getVocabularyPath() {
        return vocabularyPath;
    }

    public void setVocabularyPath(String path) {
        this.vocabularyPath = path;
    }
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package dev.langchain4j;
|
||||
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
class S2EmbeddingModel {
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private S2ModelProvider provider;
|
||||
@NestedConfigurationProperty
|
||||
private OpenAi openAi;
|
||||
@NestedConfigurationProperty
|
||||
private HuggingFace huggingFace;
|
||||
@NestedConfigurationProperty
|
||||
private LocalAi localAi;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private InProcess inProcess;
|
||||
|
||||
public S2ModelProvider getProvider() {
|
||||
return provider;
|
||||
}
|
||||
|
||||
public void setProvider(S2ModelProvider provider) {
|
||||
this.provider = provider;
|
||||
}
|
||||
|
||||
public OpenAi getOpenAi() {
|
||||
return openAi;
|
||||
}
|
||||
|
||||
public void setOpenAi(OpenAi openAi) {
|
||||
this.openAi = openAi;
|
||||
}
|
||||
|
||||
public HuggingFace getHuggingFace() {
|
||||
return huggingFace;
|
||||
}
|
||||
|
||||
public void setHuggingFace(HuggingFace huggingFace) {
|
||||
this.huggingFace = huggingFace;
|
||||
}
|
||||
|
||||
public LocalAi getLocalAi() {
|
||||
return localAi;
|
||||
}
|
||||
|
||||
public void setLocalAi(LocalAi localAi) {
|
||||
this.localAi = localAi;
|
||||
}
|
||||
|
||||
public InProcess getInProcess() {
|
||||
return inProcess;
|
||||
}
|
||||
|
||||
public void setInProcess(InProcess inProcess) {
|
||||
this.inProcess = inProcess;
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
|
||||
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
|
||||
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
|
||||
@@ -19,7 +20,6 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiModerationModel;
|
||||
import java.util.Arrays;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
@@ -29,16 +29,16 @@ import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(LangChain4jProperties.class)
|
||||
@EnableConfigurationProperties(S2LangChain4jProperties.class)
|
||||
public class S2LangChain4jAutoConfiguration {
|
||||
|
||||
@Autowired
|
||||
private LangChain4jProperties properties;
|
||||
private S2LangChain4jProperties properties;
|
||||
|
||||
@Bean
|
||||
@Lazy
|
||||
@ConditionalOnMissingBean
|
||||
ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
|
||||
ChatLanguageModel chatLanguageModel(S2LangChain4jProperties properties) {
|
||||
if (properties.getChatModel() == null) {
|
||||
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
|
||||
+ "langchain4j.chat-model.provider = openai\n"
|
||||
@@ -113,7 +113,7 @@ public class S2LangChain4jAutoConfiguration {
|
||||
@Bean
|
||||
@Lazy
|
||||
@ConditionalOnMissingBean
|
||||
LanguageModel languageModel(LangChain4jProperties properties) {
|
||||
LanguageModel languageModel(S2LangChain4jProperties properties) {
|
||||
if (properties.getLanguageModel() == null) {
|
||||
throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
|
||||
+ "langchain4j.language-model.provider = openai\n"
|
||||
@@ -187,11 +187,12 @@ public class S2LangChain4jAutoConfiguration {
|
||||
@Lazy
|
||||
@ConditionalOnMissingBean
|
||||
@Primary
|
||||
EmbeddingModel embeddingModel(LangChain4jProperties properties) {
|
||||
EmbeddingModel embeddingModel(S2LangChain4jProperties properties) {
|
||||
|
||||
if (properties.getEmbeddingModel() == null || !Arrays.stream(ModelProvider.values())
|
||||
.anyMatch(provider -> provider.equals(properties.getEmbeddingModel().getProvider()))) {
|
||||
return new AllMiniLmL6V2EmbeddingModel();
|
||||
if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
|
||||
throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n"
|
||||
+ "langchain4j.embedding-model.provider = openai\n"
|
||||
+ "langchain4j.embedding-model.openai.api-key = sk-...\n");
|
||||
}
|
||||
|
||||
switch (properties.getEmbeddingModel().getProvider()) {
|
||||
@@ -243,15 +244,23 @@ public class S2LangChain4jAutoConfiguration {
|
||||
.logRequests(localAi.getLogRequests())
|
||||
.logResponses(localAi.getLogResponses())
|
||||
.build();
|
||||
case IN_PROCESS:
|
||||
InProcess inProcess = properties.getEmbeddingModel().getInProcess();
|
||||
if (isNullOrBlank(inProcess.getModelPath())) {
|
||||
return new AllMiniLmL6V2EmbeddingModel();
|
||||
}
|
||||
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());
|
||||
|
||||
default:
|
||||
return new AllMiniLmL6V2EmbeddingModel();
|
||||
throw illegalConfiguration("Unsupported embedding model provider: %s",
|
||||
properties.getEmbeddingModel().getProvider());
|
||||
}
|
||||
}
|
||||
|
||||
@Bean
|
||||
@Lazy
|
||||
@ConditionalOnMissingBean
|
||||
ModerationModel moderationModel(LangChain4jProperties properties) {
|
||||
ModerationModel moderationModel(S2LangChain4jProperties properties) {
|
||||
if (properties.getModerationModel() == null) {
|
||||
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
|
||||
+ "langchain4j.moderation-model.provider = openai\n"
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package dev.langchain4j;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
@ConfigurationProperties(prefix = "s2.langchain4j")
|
||||
public class S2LangChain4jProperties {
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private ChatModel chatModel;
|
||||
@NestedConfigurationProperty
|
||||
private LanguageModel languageModel;
|
||||
@NestedConfigurationProperty
|
||||
private S2EmbeddingModel embeddingModel;
|
||||
@NestedConfigurationProperty
|
||||
private ModerationModel moderationModel;
|
||||
|
||||
public ChatModel getChatModel() {
|
||||
return chatModel;
|
||||
}
|
||||
|
||||
public void setChatModel(ChatModel chatModel) {
|
||||
this.chatModel = chatModel;
|
||||
}
|
||||
|
||||
public LanguageModel getLanguageModel() {
|
||||
return languageModel;
|
||||
}
|
||||
|
||||
public void setLanguageModel(LanguageModel languageModel) {
|
||||
this.languageModel = languageModel;
|
||||
}
|
||||
|
||||
public S2EmbeddingModel getEmbeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
public void setEmbeddingModel(S2EmbeddingModel s2EmbeddingModel) {
|
||||
this.embeddingModel = s2EmbeddingModel;
|
||||
}
|
||||
|
||||
public ModerationModel getModerationModel() {
|
||||
return moderationModel;
|
||||
}
|
||||
|
||||
public void setModerationModel(ModerationModel moderationModel) {
|
||||
this.moderationModel = moderationModel;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package dev.langchain4j;
|
||||
|
||||
/**
 * Supported model providers for the s2.langchain4j configuration.
 * Constant order is part of the contract (ordinal-stable); do not reorder.
 */
enum S2ModelProvider {

    /** Hosted OpenAI models. */
    OPEN_AI,

    /** HuggingFace inference API models. */
    HUGGING_FACE,

    /** Self-hosted LocalAI server. */
    LOCAL_AI,

    /** Offline ONNX model loaded inside this JVM process. */
    IN_PROCESS
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* An embedding model that runs within your Java application's process.
|
||||
* Any BERT-based model (e.g., from HuggingFace) can be used, as long as it is in ONNX format.
|
||||
* Information on how to convert models into ONNX format can be found <a
|
||||
* href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
|
||||
* Many models already converted to ONNX format are available <a href="https://huggingface.co/Xenova">here</a>.
|
||||
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
|
||||
*/
|
||||
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
|
||||
private final OnnxBertBiEncoder model;
|
||||
|
||||
/**
|
||||
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
|
||||
*/
|
||||
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
||||
URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
|
||||
if (StringUtils.isNotBlank(vocabularyPath)) {
|
||||
try {
|
||||
resource = Paths.get(vocabularyPath).toUri().toURL();
|
||||
} catch (MalformedURLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
model = loadFromFileSystem(Paths.get(pathToModel), resource);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
|
||||
*/
|
||||
public S2OnnxEmbeddingModel(String pathToModel) {
|
||||
this(pathToModel, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OnnxBertBiEncoder model() {
|
||||
return model;
|
||||
}
|
||||
|
||||
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
||||
try {
|
||||
return new OnnxBertBiEncoder(
|
||||
Files.newInputStream(pathToModel),
|
||||
vocabularyFile,
|
||||
PoolingMode.MEAN
|
||||
);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -43,31 +43,37 @@ functionCall:
|
||||
url: http://127.0.0.1:9092
|
||||
|
||||
#langchain4j config
|
||||
langchain4j:
|
||||
#1.chat-model
|
||||
chat-model:
|
||||
provider: open_ai
|
||||
openai:
|
||||
api-key: api_key
|
||||
model-name: gpt-3.5-turbo
|
||||
temperature: 0.0
|
||||
timeout: PT60S
|
||||
#2.embedding-model
|
||||
#2.1 in_memory(default)
|
||||
#2.2 open_ai
|
||||
# embedding-model:
|
||||
# provider: open_ai
|
||||
# openai:
|
||||
# api-key: api_key
|
||||
# modelName: all-minilm-l6-v2.onnx
|
||||
s2:
|
||||
langchain4j:
|
||||
#1.chat-model
|
||||
chat-model:
|
||||
provider: open_ai
|
||||
openai:
|
||||
api-key: api_key
|
||||
model-name: gpt-3.5-turbo
|
||||
temperature: 0.0
|
||||
timeout: PT60S
|
||||
#2.embedding-model
|
||||
#2.1 in_memory(default)
|
||||
embedding-model:
|
||||
provider: in_process
|
||||
# inProcess:
|
||||
# modelPath: /data/model.onnx
|
||||
# vocabularyPath: /data/onnx_vocab.txt
|
||||
#2.2 open_ai
|
||||
# embedding-model:
|
||||
# provider: open_ai
|
||||
# openai:
|
||||
# api-key: api_key
|
||||
# modelName: all-minilm-l6-v2.onnx
|
||||
|
||||
#2.3 hugging_face
|
||||
# embedding-model:
|
||||
# provider: hugging_face
|
||||
# hugging-face:
|
||||
# access-token: hg_access_token
|
||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||
# timeout: 1h
|
||||
#2.3 hugging_face
|
||||
# embedding-model:
|
||||
# provider: hugging_face
|
||||
# hugging-face:
|
||||
# access-token: hg_access_token
|
||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||
# timeout: 1h
|
||||
|
||||
#langchain4j log
|
||||
logging:
|
||||
|
||||
@@ -35,27 +35,38 @@ mybatis:
|
||||
|
||||
|
||||
#langchain4j config
|
||||
langchain4j:
|
||||
#1.chat-model
|
||||
chat-model:
|
||||
provider: open_ai
|
||||
openai:
|
||||
api-key: api_key
|
||||
model-name: gpt-3.5-turbo
|
||||
temperature: 0.0
|
||||
timeout: PT60S
|
||||
#2.embedding-model
|
||||
# embedding-model:
|
||||
# hugging-face:
|
||||
# access-token: hg_access_token
|
||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||
# timeout: 1h
|
||||
s2:
|
||||
langchain4j:
|
||||
#1.chat-model
|
||||
chat-model:
|
||||
provider: open_ai
|
||||
openai:
|
||||
api-key: api_key
|
||||
model-name: gpt-3.5-turbo
|
||||
temperature: 0.0
|
||||
timeout: PT60S
|
||||
#2.embedding-model
|
||||
#2.1 in_memory(default)
|
||||
embedding-model:
|
||||
provider: in_process
|
||||
# inProcess:
|
||||
# modelPath: /data/model.onnx
|
||||
# vocabularyPath: /data/onnx_vocab.txt
|
||||
#2.2 open_ai
|
||||
# embedding-model:
|
||||
# provider: open_ai
|
||||
# openai:
|
||||
# api-key: api_key
|
||||
# modelName: all-minilm-l6-v2.onnx
|
||||
|
||||
#2.3 hugging_face
|
||||
# embedding-model:
|
||||
# provider: hugging_face
|
||||
# hugging-face:
|
||||
# access-token: hg_access_token
|
||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||
# timeout: 1h
|
||||
|
||||
# embedding-model:
|
||||
# provider: open_ai
|
||||
# openai:
|
||||
# api-key: api_key
|
||||
# modelName: all-minilm-l6-v2.onnx
|
||||
|
||||
|
||||
#langchain4j log
|
||||
|
||||
Reference in New Issue
Block a user