Mirror of https://github.com/tencentmusic/supersonic.git (synced 2025-12-11 12:07:42 +00:00)
(improvement)(headless) Add support for the Ollama provider in the frontend and optimize the code (#1270)
@@ -1,9 +0,0 @@ deleted: S2ModelProvider.java
package com.tencent.supersonic.common.pojo.enums;

public enum S2ModelProvider {
    OPEN_AI,
    HUGGING_FACE,
    LOCAL_AI,
    IN_PROCESS
}
@@ -1,41 +0,0 @@ deleted: S2ChatModelProvider.java
package com.tencent.supersonic.common.util;

import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.apache.commons.lang3.StringUtils;

import java.time.Duration;

public class S2ChatModelProvider {

    public static ChatLanguageModel provide(LLMConfig llmConfig) {
        ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
            return chatLanguageModel;
        }
        if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
            return OpenAiChatModel
                    .builder()
                    .baseUrl(llmConfig.getBaseUrl())
                    .modelName(llmConfig.getModelName())
                    .apiKey(llmConfig.keyDecrypt())
                    .temperature(llmConfig.getTemperature())
                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                    .build();
        } else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
            return LocalAiChatModel
                    .builder()
                    .baseUrl(llmConfig.getBaseUrl())
                    .modelName(llmConfig.getModelName())
                    .temperature(llmConfig.getTemperature())
                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                    .build();
        }
        throw new RuntimeException("unsupported provider: " + llmConfig.getProvider());
    }

}
@@ -0,0 +1,8 @@ added: ChatLanguageModelFactory.java
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;

public interface ChatLanguageModelFactory {
    ChatLanguageModel create(LLMConfig llmConfig);
}
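This one-method interface is the extension seam the commit introduces: each backend becomes one small factory class. As an illustration only (not part of the diff), a stub factory could look like the sketch below; it assumes that generate(List<ChatMessage>) is the abstract method of langchain4j's ChatLanguageModel in the version the project pins, with Response.from and AiMessage.from as the usual static factories.

package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;

import java.util.List;

// Hypothetical, for illustration: a factory whose "model" simply echoes input.
public class EchoChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return new ChatLanguageModel() {
            @Override
            public Response<AiMessage> generate(List<ChatMessage> messages) {
                // Echo the last message back; a real factory would call a backend here.
                return Response.from(AiMessage.from("echo: " + messages.get(messages.size() - 1)));
            }
        };
    }
}

Note that implementing the interface alone is not enough: the registry below is populated only in its static initializer, so a new backend also needs an entry there.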
@@ -0,0 +1,33 @@ added: ChatLanguageModelProvider.java
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.apache.commons.lang3.StringUtils;

import java.util.HashMap;
import java.util.Map;

public class ChatLanguageModelProvider {
    private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();

    static {
        factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
        factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
        factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
    }

    public static ChatLanguageModel provide(LLMConfig llmConfig) {
        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
            return ContextUtils.getBean(ChatLanguageModel.class);
        }

        ChatLanguageModelFactory factory = factories.get(llmConfig.getProvider().toUpperCase());
        if (factory != null) {
            return factory.create(llmConfig);
        }

        throw new RuntimeException("Unsupported provider: " + llmConfig.getProvider());
    }
}
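The dispatch semantics of provide() in one sketch. LLMConfig's setters are not shown in this diff, so the ones used here are assumed:

// 1. Null config (or blank provider/baseUrl) falls back to the Spring-managed bean.
ChatLanguageModel fallback = ChatLanguageModelProvider.provide(null);

// 2. Lookup is case-insensitive: "ollama" is upper-cased to match ModelProvider.OLLAMA.name().
LLMConfig config = new LLMConfig();
config.setProvider("ollama");                 // assumed setter
config.setBaseUrl("http://localhost:11434");  // Ollama's default listen address
config.setModelName("llama2");                // any locally pulled model
config.setTemperature(0.0);
config.setTimeOut(60L);                       // seconds; wrapped in Duration by the factory
ChatLanguageModel ollama = ChatLanguageModelProvider.provide(config);

// 3. Enum constants without a registered factory (HUGGING_FACE, IN_PROCESS) still throw.
config.setProvider("HUGGING_FACE");
ChatLanguageModelProvider.provide(config);    // RuntimeException: Unsupported provider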
@@ -0,0 +1,20 @@ added: LocalAiChatModelFactory.java
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;

import java.time.Duration;

public class LocalAiChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return LocalAiChatModel
                .builder()
                .baseUrl(llmConfig.getBaseUrl())
                .modelName(llmConfig.getModelName())
                .temperature(llmConfig.getTemperature())
                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                .build();
    }
}
@@ -0,0 +1,9 @@ added: ModelProvider.java
package dev.langchain4j.model.provider;

public enum ModelProvider {
    OPEN_AI,
    HUGGING_FACE,
    LOCAL_AI,
    IN_PROCESS,
    OLLAMA
}
@@ -0,0 +1,20 @@ added: OllamaChatModelFactory.java
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;

import java.time.Duration;

public class OllamaChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return OllamaChatModel
                .builder()
                .baseUrl(llmConfig.getBaseUrl())
                .modelName(llmConfig.getModelName())
                .temperature(llmConfig.getTemperature())
                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                .build();
    }
}
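Unlike the OpenAI factory below, the Ollama builder carries no credential: a local Ollama daemon is unauthenticated by default. Continuing the sketch above, the factory can also be used directly, and the resulting model exposes langchain4j's blocking convenience overload:

// Direct use of the factory, bypassing the registry; config as in the previous
// sketch, with the provider set back to "OLLAMA" (setters assumed).
config.setProvider("OLLAMA");
ChatLanguageModel model = new OllamaChatModelFactory().create(config);
String reply = model.generate("Say hello.");  // synchronous, returns the completion text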
@@ -0,0 +1,21 @@ added: OpenAiChatModelFactory.java
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;

import java.time.Duration;

public class OpenAiChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return OpenAiChatModel
                .builder()
                .baseUrl(llmConfig.getBaseUrl())
                .modelName(llmConfig.getModelName())
                .apiKey(llmConfig.keyDecrypt())
                .temperature(llmConfig.getTemperature())
                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                .build();
    }
}
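OpenAI is the only backend that sends a credential, and it is stored encrypted: keyDecrypt() decrypts the key just before it reaches the builder. A final sketch, with the same assumed setters plus an assumed setApiKey counterpart to keyDecrypt():

LLMConfig config = new LLMConfig();
config.setProvider("OPEN_AI");
config.setBaseUrl("https://api.openai.com/v1"); // langchain4j's default OpenAI endpoint
config.setModelName("gpt-3.5-turbo");
config.setApiKey("<encrypted key>");            // assumed setter; decrypted by keyDecrypt()
config.setTemperature(0.0);
config.setTimeOut(60L);
ChatLanguageModel openAi = ChatLanguageModelProvider.provide(config);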