(improvement)(headless)Optimize Text2SQL prompt, explicitly ask LLM not hallucinate columns.

This commit is contained in:
jerryjzhang
2024-08-09 19:27:38 +08:00
parent 24c63c93bb
commit ecc651e12d
2 changed files with 15 additions and 12 deletions

View File

@@ -16,6 +16,8 @@ import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.stream.Collectors;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Disabled
public class Text2SQLEval extends BaseTest {
@@ -91,13 +93,13 @@ public class Text2SQLEval extends BaseTest {
@Test
public void test_term() throws Exception {
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().contains("用户");
assert result.getQueryColumns().get(1).getName().contains("停留时长");
assert result.getQueryResults().size() == 2;
assert result.getQueryColumns().size() >= 1;
assert result.getQueryColumns().stream()
.filter(c -> c.getName().contains("停留时长")).collect(Collectors.toList()).size() == 1;
assert result.getQueryResults().size() >= 1;
}
public static Agent getLLMAgent(boolean enableMultiturn) {
public Agent getLLMAgent(boolean enableMultiturn) {
Agent agent = new Agent();
agent.setName("Agent for Test");
AgentConfig agentConfig = new AgentConfig();
@@ -110,7 +112,7 @@ public class Text2SQLEval extends BaseTest {
return agent;
}
private static RuleParserTool getLLMQueryTool() {
private RuleParserTool getLLMQueryTool() {
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_LLM);
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
@@ -126,7 +128,7 @@ public class Text2SQLEval extends BaseTest {
GLM
}
private static ChatModelConfig getLLMConfig(LLMType type) {
protected ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl;
String apiKey;
String modelName;