mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) Modify the default URL and model of zhipu, and enable search support for qwen. (#1578)
This commit is contained in:
@@ -4,8 +4,8 @@ import com.google.common.collect.ImmutableMap;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
import dev.ai4j.openai4j.chat.ChatCompletionModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenModelName;
|
import dev.langchain4j.model.dashscope.QwenModelName;
|
||||||
|
import dev.langchain4j.model.zhipu.ChatCompletionModel;
|
||||||
import dev.langchain4j.provider.AzureModelFactory;
|
import dev.langchain4j.provider.AzureModelFactory;
|
||||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||||
@@ -51,6 +51,11 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
"ModelName", "", "string",
|
"ModelName", "", "string",
|
||||||
"对话模型配置", null, getModelNameDependency());
|
"对话模型配置", null, getModelNameDependency());
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
|
||||||
|
new Parameter("s2.chat.model.enableSearch", "false",
|
||||||
|
"是否启用搜索增强功能,设为false表示不启用", "", "bool",
|
||||||
|
"对话模型配置", null, getEnableSearchDependency());
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||||
new Parameter("s2.chat.model.temperature", "0.0",
|
new Parameter("s2.chat.model.temperature", "0.0",
|
||||||
"Temperature", "",
|
"Temperature", "",
|
||||||
@@ -66,7 +71,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
return Lists.newArrayList(
|
return Lists.newArrayList(
|
||||||
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
||||||
CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,12 +84,14 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||||
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
|
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
|
||||||
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
|
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
|
||||||
|
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
|
||||||
|
|
||||||
return ChatModelConfig.builder()
|
return ChatModelConfig.builder()
|
||||||
.provider(chatModelProvider)
|
.provider(chatModelProvider)
|
||||||
.baseUrl(chatModelBaseUrl)
|
.baseUrl(chatModelBaseUrl)
|
||||||
.apiKey(chatModelApiKey)
|
.apiKey(chatModelApiKey)
|
||||||
.modelName(chatModelName)
|
.modelName(chatModelName)
|
||||||
|
.enableSearch(Boolean.valueOf(enableSearch))
|
||||||
.temperature(Double.valueOf(chatModelTemperature))
|
.temperature(Double.valueOf(chatModelTemperature))
|
||||||
.timeOut(Long.valueOf(chatModelTimeout))
|
.timeOut(Long.valueOf(chatModelTimeout))
|
||||||
.endpoint(endpoint)
|
.endpoint(endpoint)
|
||||||
@@ -148,7 +155,7 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
|
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
|
||||||
OllamaModelFactory.PROVIDER, "qwen:0.5b",
|
OllamaModelFactory.PROVIDER, "qwen:0.5b",
|
||||||
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
|
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
|
||||||
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GPT_4.toString(),
|
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GLM_4.toString(),
|
||||||
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
|
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
|
||||||
AzureModelFactory.PROVIDER, "gpt-35-turbo",
|
AzureModelFactory.PROVIDER, "gpt-35-turbo",
|
||||||
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
|
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
|
||||||
@@ -166,6 +173,13 @@ public class ChatModelParameterConfig extends ParameterConfig {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_BASE_URL
|
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ public class ChatModelConfig implements Serializable {
|
|||||||
private Integer maxRetries = 3;
|
private Integer maxRetries = 3;
|
||||||
private Boolean logRequests = false;
|
private Boolean logRequests = false;
|
||||||
private Boolean logResponses = false;
|
private Boolean logResponses = false;
|
||||||
|
private Boolean enableSearch = false;
|
||||||
|
|
||||||
public String keyDecrypt() {
|
public String keyDecrypt() {
|
||||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package dev.langchain4j.model.zhipu;
|
||||||
|
|
||||||
|
public enum ChatCompletionModel {
|
||||||
|
GLM_4("glm-4"),
|
||||||
|
GLM_3_TURBO("glm-3-turbo"),
|
||||||
|
CHATGLM_TURBO("chatglm_turbo");
|
||||||
|
|
||||||
|
private final String value;
|
||||||
|
|
||||||
|
ChatCompletionModel(String value) {
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return this.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
package dev.langchain4j.model.zhipu;
|
package dev.langchain4j.model.zhipu;
|
||||||
|
|
||||||
import dev.ai4j.openai4j.chat.ChatCompletionModel;
|
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.data.message.ChatMessage;
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
@@ -55,7 +54,7 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
|
|||||||
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
|
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
|
||||||
this.temperature = getOrDefault(temperature, 0.7);
|
this.temperature = getOrDefault(temperature, 0.7);
|
||||||
this.topP = topP;
|
this.topP = topP;
|
||||||
this.model = getOrDefault(model, ChatCompletionModel.GPT_4.toString());
|
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
|
||||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||||
this.maxToken = getOrDefault(maxToken, 512);
|
this.maxToken = getOrDefault(maxToken, 512);
|
||||||
this.client = ZhipuAiClient.builder()
|
this.client = ZhipuAiClient.builder()
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import org.springframework.stereotype.Service;
|
|||||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||||
public static final String PROVIDER = "DASHSCOPE";
|
public static final String PROVIDER = "DASHSCOPE";
|
||||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
|
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
|
||||||
public static final String DEFAULT_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
@@ -24,6 +23,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
|||||||
.temperature(modelConfig.getTemperature() == null ? 0L :
|
.temperature(modelConfig.getTemperature() == null ? 0L :
|
||||||
modelConfig.getTemperature().floatValue())
|
modelConfig.getTemperature().floatValue())
|
||||||
.topP(modelConfig.getTopP())
|
.topP(modelConfig.getTopP())
|
||||||
|
.enableSearch(modelConfig.getEnableSearch())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
||||||
|
|
||||||
public interface EmbeddingStoreFactory {
|
|
||||||
EmbeddingStore createEmbeddingStore(EmbeddingStoreConfig config);
|
|
||||||
}
|
|
||||||
@@ -12,8 +12,7 @@ import org.springframework.stereotype.Service;
|
|||||||
@Service
|
@Service
|
||||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||||
public static final String PROVIDER = "ZHIPU";
|
public static final String PROVIDER = "ZHIPU";
|
||||||
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4";
|
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
|
||||||
public static final String DEFAULT_EMBEDDING_BASE_URL = "https://open.bigmodel.cn/";
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
|
|||||||
Reference in New Issue
Block a user