[improvement][launcher] Introduce Ollama local models to Text2SQLEval.

jerryjzhang committed 2024-09-20 18:48:55 +08:00
parent a200483b5c
commit 26ca5300f4
2 changed files with 56 additions and 19 deletions


@@ -133,7 +133,7 @@ public class Text2SQLEval extends BaseTest {
         AgentConfig agentConfig = new AgentConfig();
         agentConfig.getTools().add(getLLMQueryTool());
         agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
-        agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.GPT));
+        agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
         MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
         multiTurnConfig.setEnableMultiTurn(enableMultiturn);
         agent.setMultiTurnConfig(multiTurnConfig);
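
Note: with this change the eval defaults to a local model. OLLAMA_LLAMA3 resolves to http://localhost:11434 and llama3.1:8b (see LLMConfigUtils below), so running the test presumably requires a local Ollama server with that model already pulled. Switching to another of the new local types is a one-constant change, sketched here with OLLAMA_QWEN25 as an illustrative choice:

    agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_QWEN25));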


@@ -4,55 +4,92 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
 public class LLMConfigUtils {
     public enum LLMType {
-        GPT,
-        MOONSHOT,
-        DEEPSEEK,
-        QWEN,
-        GLM
+        OPENAI_GPT(false),
+        OPENAI_MOONSHOT(false),
+        OPENAI_DEEPSEEK(false),
+        OPENAI_QWEN(false),
+        OPENAI_GLM(false),
+        OLLAMA_LLAMA3(true),
+        OLLAMA_QWEN2(true),
+        OLLAMA_QWEN25(true);
+
+        private boolean isOllam;
+
+        LLMType(boolean isOllam) {
+            this.isOllam = isOllam;
+        }
     }
 
     public static ChatModelConfig getLLMConfig(LLMType type) {
         String baseUrl;
-        String apiKey;
+        String apiKey = "";
         String modelName;
         double temperature = 0.0;
         switch (type) {
-            case GLM:
+            case OLLAMA_LLAMA3:
+                baseUrl = "http://localhost:11434";
+                modelName = "llama3.1:8b";
+                break;
+            case OLLAMA_QWEN2:
+                baseUrl = "http://localhost:11434";
+                modelName = "qwen2:7b";
+                break;
+            case OLLAMA_QWEN25:
+                baseUrl = "http://localhost:11434";
+                modelName = "qwen2.5:7b";
+                break;
+            case OPENAI_GLM:
                 baseUrl = "https://open.bigmodel.cn/api/pas/v4/";
                 apiKey = "REPLACE_WITH_YOUR_KEY";
                 modelName = "glm-4";
                 break;
-            case MOONSHOT:
+            case OPENAI_MOONSHOT:
                 baseUrl = "https://api.moonshot.cn/v1";
                 apiKey = "REPLACE_WITH_YOUR_KEY";
                 modelName = "moonshot-v1-8k";
                 break;
-            case DEEPSEEK:
+            case OPENAI_DEEPSEEK:
                 baseUrl = "https://api.deepseek.com";
                 apiKey = "REPLACE_WITH_YOUR_KEY";
                 modelName = "deepseek-coder";
                 break;
-            case QWEN:
+            case OPENAI_QWEN:
                 baseUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1";
                 apiKey = "REPLACE_WITH_YOUR_KEY";
                 modelName = "qwen-turbo";
                 temperature = 0.01;
                 break;
-            case GPT:
+            case OPENAI_GPT:
             default:
                 baseUrl = "https://api.openai.com/v1";
                 apiKey = "REPLACE_WITH_YOUR_KEY";
                 modelName = "gpt-3.5-turbo";
                 temperature = 0.0;
         }
-        ChatModelConfig chatModel = new ChatModelConfig();
-        chatModel.setModelName(modelName);
-        chatModel.setBaseUrl(baseUrl);
-        chatModel.setApiKey(apiKey);
-        chatModel.setTemperature(temperature);
-        chatModel.setProvider("open_ai");
-        return chatModel;
+        ChatModelConfig chatModelConfig;
+        if (type.isOllam) {
+            chatModelConfig =
+                    ChatModelConfig.builder()
+                            .provider("ollama")
+                            .baseUrl(baseUrl)
+                            .modelName(modelName)
+                            .temperature(temperature)
+                            .timeOut(60000L)
+                            .build();
+        } else {
+            chatModelConfig =
+                    ChatModelConfig.builder()
+                            .provider("open_ai")
+                            .baseUrl(baseUrl)
+                            .apiKey(apiKey)
+                            .modelName(modelName)
+                            .temperature(temperature)
+                            .timeOut(60000L)
+                            .build();
+        }
+        return chatModelConfig;
     }
 }
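
A usage sketch for the reworked getLLMConfig, using only names from this diff (the surrounding test context is assumed): the new OLLAMA_* types yield a provider "ollama" config with no API key, while the OPENAI_* types keep the "open_ai" provider and require the REPLACE_WITH_YOUR_KEY placeholder to be filled in.

    // Local backend: provider "ollama", base URL http://localhost:11434, no API key needed.
    ChatModelConfig local = LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3);

    // Hosted backend: provider "open_ai"; the placeholder API key must be replaced before use.
    ChatModelConfig hosted = LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OPENAI_GPT);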