(improvement)(headless) Add support for the Ollama provider in the frontend and optimize the code (#1270)

This commit is contained in:
lexluo09
2024-06-28 17:29:58 +08:00
committed by GitHub
parent 7564256b0a
commit 528491717b
15 changed files with 165 additions and 62 deletions

View File

@@ -1,9 +0,0 @@
package com.tencent.supersonic.common.pojo.enums;
public enum S2ModelProvider {
OPEN_AI,
HUGGING_FACE,
LOCAL_AI,
IN_PROCESS
}

View File

@@ -1,41 +0,0 @@
package com.tencent.supersonic.common.util;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.apache.commons.lang3.StringUtils;
import java.time.Duration;
public class S2ChatModelProvider {
public static ChatLanguageModel provide(LLMConfig llmConfig) {
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
return chatLanguageModel;
}
if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
return OpenAiChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())
.apiKey(llmConfig.keyDecrypt())
.temperature(llmConfig.getTemperature())
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
.build();
} else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
return LocalAiChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())
.temperature(llmConfig.getTemperature())
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
.build();
}
throw new RuntimeException("unsupported provider: " + llmConfig.getProvider());
}
}

View File

@@ -0,0 +1,8 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
public interface ChatLanguageModelFactory {
ChatLanguageModel create(LLMConfig llmConfig);
}

View File

@@ -0,0 +1,33 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.Map;
public class ChatLanguageModelProvider {
private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();
static {
factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
}
public static ChatLanguageModel provide(LLMConfig llmConfig) {
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
return ContextUtils.getBean(ChatLanguageModel.class);
}
ChatLanguageModelFactory factory = factories.get(llmConfig.getProvider().toUpperCase());
if (factory != null) {
return factory.create(llmConfig);
}
throw new RuntimeException("Unsupported provider: " + llmConfig.getProvider());
}
}

View File

@@ -0,0 +1,20 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import java.time.Duration;
public class LocalAiChatModelFactory implements ChatLanguageModelFactory {
@Override
public ChatLanguageModel create(LLMConfig llmConfig) {
return LocalAiChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())
.temperature(llmConfig.getTemperature())
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
.build();
}
}

View File

@@ -0,0 +1,9 @@
package dev.langchain4j.model.provider;
public enum ModelProvider {
OPEN_AI,
HUGGING_FACE,
LOCAL_AI,
IN_PROCESS,
OLLAMA
}

View File

@@ -0,0 +1,20 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import java.time.Duration;
public class OllamaChatModelFactory implements ChatLanguageModelFactory {
@Override
public ChatLanguageModel create(LLMConfig llmConfig) {
return OllamaChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())
.temperature(llmConfig.getTemperature())
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
.build();
}
}

View File

@@ -0,0 +1,21 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.time.Duration;
public class OpenAiChatModelFactory implements ChatLanguageModelFactory {
@Override
public ChatLanguageModel create(LLMConfig llmConfig) {
return OpenAiChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())
.apiKey(llmConfig.keyDecrypt())
.temperature(llmConfig.getTemperature())
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
.build();
}
}