(improvement)(chat) Modify the default URL and model of zhipu, and enable search support for qwen. (#1578)

This commit is contained in:
lexluo09
2024-08-17 23:12:48 +08:00
committed by GitHub
parent 898c7100ba
commit 115cf19078
8 changed files with 40 additions and 17 deletions

View File

@@ -4,8 +4,8 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import dev.ai4j.openai4j.chat.ChatCompletionModel;
import dev.langchain4j.model.dashscope.QwenModelName; import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.model.zhipu.ChatCompletionModel;
import dev.langchain4j.provider.AzureModelFactory; import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory; import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory; import dev.langchain4j.provider.LocalAiModelFactory;
@@ -51,6 +51,11 @@ public class ChatModelParameterConfig extends ParameterConfig {
"ModelName", "", "string", "ModelName", "", "string",
"对话模型配置", null, getModelNameDependency()); "对话模型配置", null, getModelNameDependency());
/**
 * Chat-model parameter registered under the key {@code s2.chat.model.enableSearch}.
 *
 * <p>Declared with the value {@code "false"}, the type tag {@code "bool"}, and the
 * dependency list returned by {@link #getEnableSearchDependency()}. The Chinese
 * description says, roughly, "whether to enable the search-enhancement feature;
 * false means disabled".
 *
 * <p>NOTE(review): the second constructor argument is presumably the default
 * value, as with the other parameters in this class; confirm against
 * {@code Parameter}.
 */
public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
new Parameter("s2.chat.model.enableSearch", "false",
"是否启用搜索增强功能设为false表示不启用", "", "bool",
"对话模型配置", null, getEnableSearchDependency());
public static final Parameter CHAT_MODEL_TEMPERATURE = public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter("s2.chat.model.temperature", "0.0", new Parameter("s2.chat.model.temperature", "0.0",
"Temperature", "", "Temperature", "",
@@ -66,7 +71,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
return Lists.newArrayList( return Lists.newArrayList(
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT, CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
); );
} }
@@ -79,12 +84,14 @@ public class ChatModelParameterConfig extends ParameterConfig {
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT); String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT); String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY); String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
return ChatModelConfig.builder() return ChatModelConfig.builder()
.provider(chatModelProvider) .provider(chatModelProvider)
.baseUrl(chatModelBaseUrl) .baseUrl(chatModelBaseUrl)
.apiKey(chatModelApiKey) .apiKey(chatModelApiKey)
.modelName(chatModelName) .modelName(chatModelName)
.enableSearch(Boolean.valueOf(enableSearch))
.temperature(Double.valueOf(chatModelTemperature)) .temperature(Double.valueOf(chatModelTemperature))
.timeOut(Long.valueOf(chatModelTimeout)) .timeOut(Long.valueOf(chatModelTimeout))
.endpoint(endpoint) .endpoint(endpoint)
@@ -148,7 +155,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo", OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
OllamaModelFactory.PROVIDER, "qwen:0.5b", OllamaModelFactory.PROVIDER, "qwen:0.5b",
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat", QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GPT_4.toString(), ZhipuModelFactory.PROVIDER, ChatCompletionModel.GLM_4.toString(),
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j", LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
AzureModelFactory.PROVIDER, "gpt-35-turbo", AzureModelFactory.PROVIDER, "gpt-35-turbo",
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
@@ -166,6 +173,13 @@ public class ChatModelParameterConfig extends ParameterConfig {
); );
} }
/**
 * Builds the dependency list for the enable-search parameter.
 *
 * <p>The dependency is keyed on the chat-model provider parameter's name, lists
 * the Dashscope provider as the only provider entry, and maps that provider to
 * the value {@code "false"}.
 *
 * @return the dependencies produced by {@code getDependency} for those arguments
 */
private static List<Parameter.Dependency> getEnableSearchDependency() {
    // The Dashscope provider id is used both as the sole list entry and as the map key.
    final String dashscope = DashscopeModelFactory.PROVIDER;
    return getDependency(
            CHAT_MODEL_PROVIDER.getName(),
            Lists.newArrayList(dashscope),
            ImmutableMap.of(dashscope, "false"));
}
private static List<Parameter.Dependency> getSecretKeyDependency() { private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(), return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), Lists.newArrayList(QianfanModelFactory.PROVIDER),

View File

@@ -102,7 +102,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL, DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL, QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_BASE_URL ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL
) )
); );
} }

View File

@@ -26,6 +26,7 @@ public class ChatModelConfig implements Serializable {
private Integer maxRetries = 3; private Integer maxRetries = 3;
private Boolean logRequests = false; private Boolean logRequests = false;
private Boolean logResponses = false; private Boolean logResponses = false;
private Boolean enableSearch = false;
public String keyDecrypt() { public String keyDecrypt() {
return AESEncryptionUtil.aesDecryptECB(getApiKey()); return AESEncryptionUtil.aesDecryptECB(getApiKey());

View File

@@ -0,0 +1,18 @@
package dev.langchain4j.model.zhipu;
/**
 * Chat-completion model names for the Zhipu integration.
 *
 * <p>Each constant carries a model-name string (for example {@code "glm-4"}),
 * which {@link #toString()} returns in place of the constant's Java identifier.
 */
public enum ChatCompletionModel {

    GLM_4("glm-4"),
    GLM_3_TURBO("glm-3-turbo"),
    CHATGLM_TURBO("chatglm_turbo");

    /** Model-name string associated with this constant. */
    private final String modelId;

    /**
     * @param modelId the model-name string reported by {@link #toString()}
     */
    ChatCompletionModel(final String modelId) {
        this.modelId = modelId;
    }

    /**
     * Returns the model-name string rather than the enum constant name.
     *
     * @return the model-name string supplied at construction
     */
    @Override
    public String toString() {
        return modelId;
    }
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.model.zhipu; package dev.langchain4j.model.zhipu;
import dev.ai4j.openai4j.chat.ChatCompletionModel;
import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ChatMessage;
@@ -55,7 +54,7 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/"); this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
this.temperature = getOrDefault(temperature, 0.7); this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP; this.topP = topP;
this.model = getOrDefault(model, ChatCompletionModel.GPT_4.toString()); this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
this.maxRetries = getOrDefault(maxRetries, 3); this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512); this.maxToken = getOrDefault(maxToken, 512);
this.client = ZhipuAiClient.builder() this.client = ZhipuAiClient.builder()

View File

@@ -13,7 +13,6 @@ import org.springframework.stereotype.Service;
public class DashscopeModelFactory implements ModelFactory, InitializingBean { public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE"; public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"; public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
public static final String DEFAULT_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -24,6 +23,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
.temperature(modelConfig.getTemperature() == null ? 0L : .temperature(modelConfig.getTemperature() == null ? 0L :
modelConfig.getTemperature().floatValue()) modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP()) .topP(modelConfig.getTopP())
.enableSearch(modelConfig.getEnableSearch())
.build(); .build();
} }

View File

@@ -1,8 +0,0 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import dev.langchain4j.store.embedding.EmbeddingStore;
/**
 * Factory contract for obtaining an {@link EmbeddingStore} from an
 * {@link EmbeddingStoreConfig}.
 */
public interface EmbeddingStoreFactory {
/**
 * Creates an embedding store for the given configuration.
 *
 * @param config the embedding-store configuration to build from
 * @return the embedding store produced for {@code config}
 */
EmbeddingStore createEmbeddingStore(EmbeddingStoreConfig config);
}

View File

@@ -12,8 +12,7 @@ import org.springframework.stereotype.Service;
@Service @Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean { public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU"; public static final String PROVIDER = "ZHIPU";
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"; public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
public static final String DEFAULT_EMBEDDING_BASE_URL = "https://open.bigmodel.cn/";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {