[improvement][launcher] Introduce Ollama local models to Text2SQLEval.

This commit is contained in:
jerryjzhang
2024-09-20 18:48:55 +08:00
parent a200483b5c
commit 26ca5300f4
2 changed files with 56 additions and 19 deletions

View File

@@ -133,7 +133,7 @@ public class Text2SQLEval extends BaseTest {
AgentConfig agentConfig = new AgentConfig(); AgentConfig agentConfig = new AgentConfig();
agentConfig.getTools().add(getLLMQueryTool()); agentConfig.getTools().add(getLLMQueryTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.GPT)); agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(enableMultiturn); multiTurnConfig.setEnableMultiTurn(enableMultiturn);
agent.setMultiTurnConfig(multiTurnConfig); agent.setMultiTurnConfig(multiTurnConfig);

View File

@@ -4,55 +4,92 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
public class LLMConfigUtils { public class LLMConfigUtils {
public enum LLMType { public enum LLMType {
GPT, OPENAI_GPT(false),
MOONSHOT, OPENAI_MOONSHOT(false),
DEEPSEEK, OPENAI_DEEPSEEK(false),
QWEN, OPENAI_QWEN(false),
GLM OPENAI_GLM(false),
OLLAMA_LLAMA3(true),
OLLAMA_QWEN2(true),
OLLAMA_QWEN25(true);
private boolean isOllam;
LLMType(boolean isOllam) {
this.isOllam = isOllam;
}
} }
public static ChatModelConfig getLLMConfig(LLMType type) { public static ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl; String baseUrl;
String apiKey; String apiKey = "";
String modelName; String modelName;
double temperature = 0.0; double temperature = 0.0;
switch (type) { switch (type) {
case GLM: case OLLAMA_LLAMA3:
baseUrl = "http://localhost:11434";
modelName = "llama3.1:8b";
break;
case OLLAMA_QWEN2:
baseUrl = "http://localhost:11434";
modelName = "qwen2:7b";
break;
case OLLAMA_QWEN25:
baseUrl = "http://localhost:11434";
modelName = "qwen2.5:7b";
break;
case OPENAI_GLM:
baseUrl = "https://open.bigmodel.cn/api/pas/v4/"; baseUrl = "https://open.bigmodel.cn/api/pas/v4/";
apiKey = "REPLACE_WITH_YOUR_KEY"; apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "glm-4"; modelName = "glm-4";
break; break;
case MOONSHOT: case OPENAI_MOONSHOT:
baseUrl = "https://api.moonshot.cn/v1"; baseUrl = "https://api.moonshot.cn/v1";
apiKey = "REPLACE_WITH_YOUR_KEY"; apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "moonshot-v1-8k"; modelName = "moonshot-v1-8k";
break; break;
case DEEPSEEK: case OPENAI_DEEPSEEK:
baseUrl = "https://api.deepseek.com"; baseUrl = "https://api.deepseek.com";
apiKey = "REPLACE_WITH_YOUR_KEY"; apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "deepseek-coder"; modelName = "deepseek-coder";
break; break;
case QWEN: case OPENAI_QWEN:
baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1"; baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
apiKey = "REPLACE_WITH_YOUR_KEY"; apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "qwen-turbo"; modelName = "qwen-turbo";
temperature = 0.01; temperature = 0.01;
break; break;
case GPT: case OPENAI_GPT:
default: default:
baseUrl = "https://api.openai.com/v1"; baseUrl = "https://api.openai.com/v1";
apiKey = "REPLACE_WITH_YOUR_KEY"; apiKey = "REPLACE_WITH_YOUR_KEY";
modelName = "gpt-3.5-turbo"; modelName = "gpt-3.5-turbo";
temperature = 0.0; temperature = 0.0;
} }
ChatModelConfig chatModel = new ChatModelConfig();
chatModel.setModelName(modelName);
chatModel.setBaseUrl(baseUrl);
chatModel.setApiKey(apiKey);
chatModel.setTemperature(temperature);
chatModel.setProvider("open_ai");
return chatModel; ChatModelConfig chatModelConfig;
if (type.isOllam) {
chatModelConfig =
ChatModelConfig.builder()
.provider("ollama")
.baseUrl(baseUrl)
.modelName(modelName)
.temperature(temperature)
.timeOut(60000L)
.build();
} else {
chatModelConfig =
ChatModelConfig.builder()
.provider("open_ai")
.baseUrl(baseUrl)
.apiKey(apiKey)
.modelName(modelName)
.temperature(temperature)
.timeOut(60000L)
.build();
}
return chatModelConfig;
} }
} }