From 26ca5300f4791c88a003dcf473a9535946f65b58 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Fri, 20 Sep 2024 18:48:55 +0800 Subject: [PATCH] [improvement][launcher]Introduce ollama local models to `Text2SQLEval`. --- .../supersonic/evaluation/Text2SQLEval.java | 2 +- .../supersonic/util/LLMConfigUtils.java | 73 ++++++++++++++----- 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 035835b8a..53deb9fde 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -133,7 +133,7 @@ public class Text2SQLEval extends BaseTest { AgentConfig agentConfig = new AgentConfig(); agentConfig.getTools().add(getLLMQueryTool()); agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); - agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.GPT)); + agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); multiTurnConfig.setEnableMultiTurn(enableMultiturn); agent.setMultiTurnConfig(multiTurnConfig); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java index 4dc6bb32f..0f43b415c 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java @@ -4,55 +4,92 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig; public class LLMConfigUtils { public enum LLMType { - GPT, - MOONSHOT, - DEEPSEEK, - QWEN, - GLM + OPENAI_GPT(false), + OPENAI_MOONSHOT(false), + OPENAI_DEEPSEEK(false), + OPENAI_QWEN(false), 
+ OPENAI_GLM(false), + OLLAMA_LLAMA3(true), + OLLAMA_QWEN2(true), + OLLAMA_QWEN25(true); + + private boolean isOllama; + + LLMType(boolean isOllama) { + this.isOllama = isOllama; + } } public static ChatModelConfig getLLMConfig(LLMType type) { String baseUrl; - String apiKey; + String apiKey = ""; String modelName; double temperature = 0.0; switch (type) { - case GLM: + case OLLAMA_LLAMA3: + baseUrl = "http://localhost:11434"; + modelName = "llama3.1:8b"; + break; + case OLLAMA_QWEN2: + baseUrl = "http://localhost:11434"; + modelName = "qwen2:7b"; + break; + case OLLAMA_QWEN25: + baseUrl = "http://localhost:11434"; + modelName = "qwen2.5:7b"; + break; + case OPENAI_GLM: baseUrl = "https://open.bigmodel.cn/api/pas/v4/"; apiKey = "REPLACE_WITH_YOUR_KEY"; modelName = "glm-4"; break; - case MOONSHOT: + case OPENAI_MOONSHOT: baseUrl = "https://api.moonshot.cn/v1"; apiKey = "REPLACE_WITH_YOUR_KEY"; modelName = "moonshot-v1-8k"; break; - case DEEPSEEK: + case OPENAI_DEEPSEEK: baseUrl = "https://api.deepseek.com"; apiKey = "REPLACE_WITH_YOUR_KEY"; modelName = "deepseek-coder"; break; - case QWEN: + case OPENAI_QWEN: baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1"; apiKey = "REPLACE_WITH_YOUR_KEY"; modelName = "qwen-turbo"; temperature = 0.01; break; - case GPT: + case OPENAI_GPT: default: baseUrl = "https://api.openai.com/v1"; apiKey = "REPLACE_WITH_YOUR_KEY"; modelName = "gpt-3.5-turbo"; temperature = 0.0; } - ChatModelConfig chatModel = new ChatModelConfig(); - chatModel.setModelName(modelName); - chatModel.setBaseUrl(baseUrl); - chatModel.setApiKey(apiKey); - chatModel.setTemperature(temperature); - chatModel.setProvider("open_ai"); - return chatModel; + ChatModelConfig chatModelConfig; + if (type.isOllama) { + chatModelConfig = + ChatModelConfig.builder() + .provider("ollama") + .baseUrl(baseUrl) + .modelName(modelName) + .temperature(temperature) + .timeOut(60000L) + .build(); + } else { + chatModelConfig = + ChatModelConfig.builder() + .provider("open_ai") + 
.baseUrl(baseUrl) + .apiKey(apiKey) + .modelName(modelName) + .temperature(temperature) + .timeOut(60000L) + .build(); + } + + return chatModelConfig; } }