mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-18 08:17:18 +00:00
[improvement](chat) Add an in_process provider and support offline loading of local embedding models. (#505)
This commit is contained in:
@@ -186,6 +186,10 @@
|
|||||||
<artifactId>log4j-api</artifactId>
|
<artifactId>log4j-api</artifactId>
|
||||||
<version>${apache.log4j.version}</version>
|
<version>${apache.log4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-embeddings</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package dev.langchain4j;
|
||||||
|
|
||||||
|
class InProcess {
|
||||||
|
|
||||||
|
/***
|
||||||
|
* the model local path
|
||||||
|
*/
|
||||||
|
private String modelPath;
|
||||||
|
|
||||||
|
/***
|
||||||
|
* the model's vocabulary local path
|
||||||
|
*/
|
||||||
|
private String vocabularyPath;
|
||||||
|
|
||||||
|
public String getModelPath() {
|
||||||
|
return modelPath;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setModelPath(String modelPath) {
|
||||||
|
this.modelPath = modelPath;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getVocabularyPath() {
|
||||||
|
return vocabularyPath;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setVocabularyPath(String vocabularyPath) {
|
||||||
|
this.vocabularyPath = vocabularyPath;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package dev.langchain4j;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||||
|
|
||||||
|
class S2EmbeddingModel {
|
||||||
|
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private S2ModelProvider provider;
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private OpenAi openAi;
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private HuggingFace huggingFace;
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private LocalAi localAi;
|
||||||
|
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private InProcess inProcess;
|
||||||
|
|
||||||
|
public S2ModelProvider getProvider() {
|
||||||
|
return provider;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setProvider(S2ModelProvider provider) {
|
||||||
|
this.provider = provider;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OpenAi getOpenAi() {
|
||||||
|
return openAi;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setOpenAi(OpenAi openAi) {
|
||||||
|
this.openAi = openAi;
|
||||||
|
}
|
||||||
|
|
||||||
|
public HuggingFace getHuggingFace() {
|
||||||
|
return huggingFace;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setHuggingFace(HuggingFace huggingFace) {
|
||||||
|
this.huggingFace = huggingFace;
|
||||||
|
}
|
||||||
|
|
||||||
|
public LocalAi getLocalAi() {
|
||||||
|
return localAi;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLocalAi(LocalAi localAi) {
|
||||||
|
this.localAi = localAi;
|
||||||
|
}
|
||||||
|
|
||||||
|
public InProcess getInProcess() {
|
||||||
|
return inProcess;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setInProcess(InProcess inProcess) {
|
||||||
|
this.inProcess = inProcess;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
|||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||||
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
|
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
|
||||||
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
|
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
|
||||||
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
|
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
|
||||||
@@ -19,7 +20,6 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
|
|||||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiLanguageModel;
|
import dev.langchain4j.model.openai.OpenAiLanguageModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiModerationModel;
|
import dev.langchain4j.model.openai.OpenAiModerationModel;
|
||||||
import java.util.Arrays;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
@@ -29,16 +29,16 @@ import org.springframework.context.annotation.Lazy;
|
|||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@EnableConfigurationProperties(LangChain4jProperties.class)
|
@EnableConfigurationProperties(S2LangChain4jProperties.class)
|
||||||
public class S2LangChain4jAutoConfiguration {
|
public class S2LangChain4jAutoConfiguration {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private LangChain4jProperties properties;
|
private S2LangChain4jProperties properties;
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
@Lazy
|
@Lazy
|
||||||
@ConditionalOnMissingBean
|
@ConditionalOnMissingBean
|
||||||
ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
|
ChatLanguageModel chatLanguageModel(S2LangChain4jProperties properties) {
|
||||||
if (properties.getChatModel() == null) {
|
if (properties.getChatModel() == null) {
|
||||||
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
|
||||||
+ "langchain4j.chat-model.provider = openai\n"
|
+ "langchain4j.chat-model.provider = openai\n"
|
||||||
@@ -113,7 +113,7 @@ public class S2LangChain4jAutoConfiguration {
|
|||||||
@Bean
|
@Bean
|
||||||
@Lazy
|
@Lazy
|
||||||
@ConditionalOnMissingBean
|
@ConditionalOnMissingBean
|
||||||
LanguageModel languageModel(LangChain4jProperties properties) {
|
LanguageModel languageModel(S2LangChain4jProperties properties) {
|
||||||
if (properties.getLanguageModel() == null) {
|
if (properties.getLanguageModel() == null) {
|
||||||
throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
|
||||||
+ "langchain4j.language-model.provider = openai\n"
|
+ "langchain4j.language-model.provider = openai\n"
|
||||||
@@ -187,11 +187,12 @@ public class S2LangChain4jAutoConfiguration {
|
|||||||
@Lazy
|
@Lazy
|
||||||
@ConditionalOnMissingBean
|
@ConditionalOnMissingBean
|
||||||
@Primary
|
@Primary
|
||||||
EmbeddingModel embeddingModel(LangChain4jProperties properties) {
|
EmbeddingModel embeddingModel(S2LangChain4jProperties properties) {
|
||||||
|
|
||||||
if (properties.getEmbeddingModel() == null || !Arrays.stream(ModelProvider.values())
|
if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
|
||||||
.anyMatch(provider -> provider.equals(properties.getEmbeddingModel().getProvider()))) {
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n"
|
||||||
return new AllMiniLmL6V2EmbeddingModel();
|
+ "langchain4j.embedding-model.provider = openai\n"
|
||||||
|
+ "langchain4j.embedding-model.openai.api-key = sk-...\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (properties.getEmbeddingModel().getProvider()) {
|
switch (properties.getEmbeddingModel().getProvider()) {
|
||||||
@@ -243,15 +244,23 @@ public class S2LangChain4jAutoConfiguration {
|
|||||||
.logRequests(localAi.getLogRequests())
|
.logRequests(localAi.getLogRequests())
|
||||||
.logResponses(localAi.getLogResponses())
|
.logResponses(localAi.getLogResponses())
|
||||||
.build();
|
.build();
|
||||||
default:
|
case IN_PROCESS:
|
||||||
|
InProcess inProcess = properties.getEmbeddingModel().getInProcess();
|
||||||
|
if (isNullOrBlank(inProcess.getModelPath())) {
|
||||||
return new AllMiniLmL6V2EmbeddingModel();
|
return new AllMiniLmL6V2EmbeddingModel();
|
||||||
}
|
}
|
||||||
|
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw illegalConfiguration("Unsupported embedding model provider: %s",
|
||||||
|
properties.getEmbeddingModel().getProvider());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
@Lazy
|
@Lazy
|
||||||
@ConditionalOnMissingBean
|
@ConditionalOnMissingBean
|
||||||
ModerationModel moderationModel(LangChain4jProperties properties) {
|
ModerationModel moderationModel(S2LangChain4jProperties properties) {
|
||||||
if (properties.getModerationModel() == null) {
|
if (properties.getModerationModel() == null) {
|
||||||
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
|
||||||
+ "langchain4j.moderation-model.provider = openai\n"
|
+ "langchain4j.moderation-model.provider = openai\n"
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package dev.langchain4j;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||||
|
|
||||||
|
@ConfigurationProperties(prefix = "s2.langchain4j")
|
||||||
|
public class S2LangChain4jProperties {
|
||||||
|
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private ChatModel chatModel;
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private LanguageModel languageModel;
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private S2EmbeddingModel embeddingModel;
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
private ModerationModel moderationModel;
|
||||||
|
|
||||||
|
public ChatModel getChatModel() {
|
||||||
|
return chatModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setChatModel(ChatModel chatModel) {
|
||||||
|
this.chatModel = chatModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public LanguageModel getLanguageModel() {
|
||||||
|
return languageModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLanguageModel(LanguageModel languageModel) {
|
||||||
|
this.languageModel = languageModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public S2EmbeddingModel getEmbeddingModel() {
|
||||||
|
return embeddingModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setEmbeddingModel(S2EmbeddingModel s2EmbeddingModel) {
|
||||||
|
this.embeddingModel = s2EmbeddingModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ModerationModel getModerationModel() {
|
||||||
|
return moderationModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setModerationModel(ModerationModel moderationModel) {
|
||||||
|
this.moderationModel = moderationModel;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package dev.langchain4j;
|
||||||
|
|
||||||
|
enum S2ModelProvider {
|
||||||
|
|
||||||
|
OPEN_AI,
|
||||||
|
HUGGING_FACE,
|
||||||
|
LOCAL_AI,
|
||||||
|
IN_PROCESS
|
||||||
|
}
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package dev.langchain4j.model.embedding;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.MalformedURLException;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An embedding model that runs within your Java application's process.
|
||||||
|
* Any BERT-based model (e.g., from HuggingFace) can be used, as long as it is in ONNX format.
|
||||||
|
* Information on how to convert models into ONNX format can be found <a
|
||||||
|
* href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
|
||||||
|
* Many models already converted to ONNX format are available <a href="https://huggingface.co/Xenova">here</a>.
|
||||||
|
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
|
||||||
|
*/
|
||||||
|
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||||
|
|
||||||
|
private final OnnxBertBiEncoder model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
|
||||||
|
*/
|
||||||
|
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
||||||
|
URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
|
||||||
|
if (StringUtils.isNotBlank(vocabularyPath)) {
|
||||||
|
try {
|
||||||
|
resource = Paths.get(vocabularyPath).toUri().toURL();
|
||||||
|
} catch (MalformedURLException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model = loadFromFileSystem(Paths.get(pathToModel), resource);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param pathToModel The path to the .onnx model file (e.g., "/home/me/model.onnx").
|
||||||
|
*/
|
||||||
|
public S2OnnxEmbeddingModel(String pathToModel) {
|
||||||
|
this(pathToModel, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected OnnxBertBiEncoder model() {
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
||||||
|
try {
|
||||||
|
return new OnnxBertBiEncoder(
|
||||||
|
Files.newInputStream(pathToModel),
|
||||||
|
vocabularyFile,
|
||||||
|
PoolingMode.MEAN
|
||||||
|
);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -43,7 +43,8 @@ functionCall:
|
|||||||
url: http://127.0.0.1:9092
|
url: http://127.0.0.1:9092
|
||||||
|
|
||||||
#langchain4j config
|
#langchain4j config
|
||||||
langchain4j:
|
s2:
|
||||||
|
langchain4j:
|
||||||
#1.chat-model
|
#1.chat-model
|
||||||
chat-model:
|
chat-model:
|
||||||
provider: open_ai
|
provider: open_ai
|
||||||
@@ -54,20 +55,25 @@ langchain4j:
|
|||||||
timeout: PT60S
|
timeout: PT60S
|
||||||
#2.embedding-model
|
#2.embedding-model
|
||||||
#2.1 in_memory(default)
|
#2.1 in_memory(default)
|
||||||
|
embedding-model:
|
||||||
|
provider: in_process
|
||||||
|
# inProcess:
|
||||||
|
# modelPath: /data/model.onnx
|
||||||
|
# vocabularyPath: /data/onnx_vocab.txt
|
||||||
#2.2 open_ai
|
#2.2 open_ai
|
||||||
# embedding-model:
|
# embedding-model:
|
||||||
# provider: open_ai
|
# provider: open_ai
|
||||||
# openai:
|
# openai:
|
||||||
# api-key: api_key
|
# api-key: api_key
|
||||||
# modelName: all-minilm-l6-v2.onnx
|
# modelName: all-minilm-l6-v2.onnx
|
||||||
|
|
||||||
#2.2 hugging_face
|
#2.2 hugging_face
|
||||||
# embedding-model:
|
# embedding-model:
|
||||||
# provider: hugging_face
|
# provider: hugging_face
|
||||||
# hugging-face:
|
# hugging-face:
|
||||||
# access-token: hg_access_token
|
# access-token: hg_access_token
|
||||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||||
# timeout: 1h
|
# timeout: 1h
|
||||||
|
|
||||||
#langchain4j log
|
#langchain4j log
|
||||||
logging:
|
logging:
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ mybatis:
|
|||||||
|
|
||||||
|
|
||||||
#langchain4j config
|
#langchain4j config
|
||||||
langchain4j:
|
s2:
|
||||||
|
langchain4j:
|
||||||
#1.chat-model
|
#1.chat-model
|
||||||
chat-model:
|
chat-model:
|
||||||
provider: open_ai
|
provider: open_ai
|
||||||
@@ -45,17 +46,27 @@ langchain4j:
|
|||||||
temperature: 0.0
|
temperature: 0.0
|
||||||
timeout: PT60S
|
timeout: PT60S
|
||||||
#2.embedding-model
|
#2.embedding-model
|
||||||
# embedding-model:
|
#2.1 in_memory(default)
|
||||||
# hugging-face:
|
embedding-model:
|
||||||
# access-token: hg_access_token
|
provider: in_process
|
||||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
# inProcess:
|
||||||
# timeout: 1h
|
# modelPath: /data/model.onnx
|
||||||
|
# vocabularyPath: /data/onnx_vocab.txt
|
||||||
|
#2.2 open_ai
|
||||||
|
# embedding-model:
|
||||||
|
# provider: open_ai
|
||||||
|
# openai:
|
||||||
|
# api-key: api_key
|
||||||
|
# modelName: all-minilm-l6-v2.onnx
|
||||||
|
|
||||||
|
#2.2 hugging_face
|
||||||
|
# embedding-model:
|
||||||
|
# provider: hugging_face
|
||||||
|
# hugging-face:
|
||||||
|
# access-token: hg_access_token
|
||||||
|
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
# timeout: 1h
|
||||||
|
|
||||||
# embedding-model:
|
|
||||||
# provider: open_ai
|
|
||||||
# openai:
|
|
||||||
# api-key: api_key
|
|
||||||
# modelName: all-minilm-l6-v2.onnx
|
|
||||||
|
|
||||||
|
|
||||||
#langchain4j log
|
#langchain4j log
|
||||||
|
|||||||
Reference in New Issue
Block a user