diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java index 2b3f743dd..653f00ce8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java @@ -4,8 +4,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Parameter; -import dev.ai4j.openai4j.chat.ChatCompletionModel; import dev.langchain4j.model.dashscope.QwenModelName; +import dev.langchain4j.model.zhipu.ChatCompletionModel; import dev.langchain4j.provider.AzureModelFactory; import dev.langchain4j.provider.DashscopeModelFactory; import dev.langchain4j.provider.LocalAiModelFactory; @@ -51,6 +51,11 @@ public class ChatModelParameterConfig extends ParameterConfig { "ModelName", "", "string", "对话模型配置", null, getModelNameDependency()); + public static final Parameter CHAT_MODEL_ENABLE_SEARCH = + new Parameter("s2.chat.model.enableSearch", "false", + "是否启用搜索增强功能,设为false表示不启用", "", "bool", + "对话模型配置", null, getEnableSearchDependency()); + public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter("s2.chat.model.temperature", "0.0", "Temperature", "", @@ -66,7 +71,7 @@ public class ChatModelParameterConfig extends ParameterConfig { return Lists.newArrayList( CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT, CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, - CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT + CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT ); } @@ -79,12 +84,14 @@ public class ChatModelParameterConfig extends ParameterConfig { String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT); String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT); String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY); + String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH); return ChatModelConfig.builder() .provider(chatModelProvider) .baseUrl(chatModelBaseUrl) .apiKey(chatModelApiKey) .modelName(chatModelName) + .enableSearch(Boolean.valueOf(enableSearch)) .temperature(Double.valueOf(chatModelTemperature)) .timeOut(Long.valueOf(chatModelTimeout)) .endpoint(endpoint) @@ -148,7 +155,7 @@ public class ChatModelParameterConfig extends ParameterConfig { OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo", OllamaModelFactory.PROVIDER, "qwen:0.5b", QianfanModelFactory.PROVIDER, "Llama-2-70b-chat", - ZhipuModelFactory.PROVIDER, ChatCompletionModel.GPT_4.toString(), + ZhipuModelFactory.PROVIDER, ChatCompletionModel.GLM_4.toString(), LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j", AzureModelFactory.PROVIDER, "gpt-35-turbo", DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS @@ -166,6 +173,13 @@ public class ChatModelParameterConfig extends ParameterConfig { ); } + private static List getEnableSearchDependency() { + return getDependency(CHAT_MODEL_PROVIDER.getName(), + Lists.newArrayList(DashscopeModelFactory.PROVIDER), + ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false") + ); + } + private static List getSecretKeyDependency() { return getDependency(CHAT_MODEL_PROVIDER.getName(), Lists.newArrayList(QianfanModelFactory.PROVIDER), diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java index 03ebb7947..2d1ff6a37 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -102,7 +102,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig { AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL, QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL, - ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_BASE_URL + ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL ) ); } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java index e96478e29..6bacb3291 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java @@ -26,6 +26,7 @@ public class ChatModelConfig implements Serializable { private Integer maxRetries = 3; private Boolean logRequests = false; private Boolean logResponses = false; + private Boolean enableSearch = false; public String keyDecrypt() { return AESEncryptionUtil.aesDecryptECB(getApiKey()); diff --git a/common/src/main/java/dev/langchain4j/model/zhipu/ChatCompletionModel.java b/common/src/main/java/dev/langchain4j/model/zhipu/ChatCompletionModel.java new file mode 100644 index 000000000..b0496e819 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/model/zhipu/ChatCompletionModel.java @@ -0,0 +1,18 @@ +package dev.langchain4j.model.zhipu; + +public enum ChatCompletionModel { + GLM_4("glm-4"), + GLM_3_TURBO("glm-3-turbo"), + CHATGLM_TURBO("chatglm_turbo"); + + private final String value; + + ChatCompletionModel(String value) { + this.value = value; + } + + @Override + public String toString() { + return this.value; + } +} \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java b/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java index fbdb39061..7c887e245 100644 --- a/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java +++ b/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java @@ -1,6 +1,5 @@ package dev.langchain4j.model.zhipu; -import dev.ai4j.openai4j.chat.ChatCompletionModel; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; @@ -55,7 +54,7 @@ public class ZhipuAiChatModel implements ChatLanguageModel { this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/"); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; - this.model = getOrDefault(model, ChatCompletionModel.GPT_4.toString()); + this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString()); this.maxRetries = getOrDefault(maxRetries, 3); this.maxToken = getOrDefault(maxToken, 512); this.client = ZhipuAiClient.builder() diff --git a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java index d7741623d..13de6ee4b 100644 --- a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java @@ -13,7 +13,6 @@ import org.springframework.stereotype.Service; public class DashscopeModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "DASHSCOPE"; public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"; - public static final String DEFAULT_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { @@ -24,6 +23,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean { .temperature(modelConfig.getTemperature() == null ? 0L : modelConfig.getTemperature().floatValue()) .topP(modelConfig.getTopP()) + .enableSearch(modelConfig.getEnableSearch()) .build(); } diff --git a/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreFactory.java deleted file mode 100644 index fea4cf7ad..000000000 --- a/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreFactory.java +++ /dev/null @@ -1,8 +0,0 @@ -package dev.langchain4j.provider; - -import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; -import dev.langchain4j.store.embedding.EmbeddingStore; - -public interface EmbeddingStoreFactory { - EmbeddingStore createEmbeddingStore(EmbeddingStoreConfig config); -} diff --git a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java index 68594eee1..3620500cc 100644 --- a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java @@ -12,8 +12,7 @@ import org.springframework.stereotype.Service; @Service public class ZhipuModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "ZHIPU"; - public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"; - public static final String DEFAULT_EMBEDDING_BASE_URL = "https://open.bigmodel.cn/"; + public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {