(improvement)(chat) Support large models qianfan, zhipu, Azure, LocalAi, Dashscope, and handle the apiKey configuration as hidden. (#1552)

This commit is contained in:
lexluo09
2024-08-11 23:28:24 +08:00
committed by GitHub
parent 2f2f493d17
commit 8b01dac8d4
12 changed files with 315 additions and 184 deletions

View File

@@ -14,15 +14,19 @@ import java.time.Duration;
@Service
public class AzureModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "AZURE";
public static final String DEFAULT_BASE_URL = "https://xxxx.openai.azure.com/";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl())
.endpoint(modelConfig.getEndpoint())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()));
.maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
.logRequestsAndResponses(modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
return builder.build();
}

View File

@@ -12,6 +12,7 @@ import org.springframework.stereotype.Service;
@Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -21,6 +22,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature() == null ? 0L :
modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP())
.build();
}

View File

@@ -14,7 +14,7 @@ import java.time.Duration;
@Service
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "LOCAL_AI";
public static final String DEFAULT_BASE_URL = "http://localhost:8080";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel
@@ -23,6 +23,10 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.maxRetries(modelConfig.getMaxRetries())
.build();
}
@@ -31,6 +35,9 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
return LocalAiEmbeddingModel.builder()
.baseUrl(embeddingModel.getBaseUrl())
.modelName(embeddingModel.getModelName())
.maxRetries(embeddingModel.getMaxRetries())
.logRequests(embeddingModel.getLogRequests())
.logResponses(embeddingModel.getLogResponses())
.build();
}

View File

@@ -13,7 +13,9 @@ import java.time.Duration;
@Service
public class OllamaModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OLLAMA";
public static final String DEFAULT_BASE_URL = "http://localhost:11434";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -23,6 +25,10 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
}

View File

@@ -13,7 +13,9 @@ import java.time.Duration;
@Service
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OPEN_AI";
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -23,7 +25,11 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName())
.apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
}

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanChatModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -12,10 +13,22 @@ import org.springframework.stereotype.Service;
public class QianfanModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "QIANFAN";
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null;
return QianfanChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.secretKey(modelConfig.getSecretKey())
.endpoint(modelConfig.getEndpoint())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
}
@Override

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -11,10 +12,20 @@ import org.springframework.stereotype.Service;
@Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU";
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null;
return ZhipuAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.model(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
}
@Override