(improvement)(headless) Optimize Text2SQL prompt: explicitly ask the LLM not to hallucinate columns.

This commit is contained in:
jerryjzhang
2024-08-09 19:27:38 +08:00
parent 24c63c93bb
commit ecc651e12d
2 changed files with 15 additions and 12 deletions

View File

@@ -31,11 +31,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "please convert it to a SQL query so that relevant data could be returned " + "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database.\n" + "by executing the SQL query against underlying database.\n"
+ "#Rules:" + "#Rules:"
+ "1.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
+ "2.ALWAYS calculate the absolute date range by yourself." + "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + "3.ALWAYS calculate the absolute date range by yourself."
+ "4.DO NOT miss the AGGREGATE operator of metrics, always add it if needed." + "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "5.ONLY respond with the converted SQL statement.\n" + "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
+ "6.ONLY respond with the converted SQL statement.\n"
+ "#Exemplars:\n{{exemplar}}" + "#Exemplars:\n{{exemplar}}"
+ "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:"; + "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:";

View File

@@ -16,6 +16,8 @@ import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance;
import java.util.stream.Collectors;
@TestInstance(TestInstance.Lifecycle.PER_CLASS) @TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Disabled @Disabled
public class Text2SQLEval extends BaseTest { public class Text2SQLEval extends BaseTest {
@@ -91,13 +93,13 @@ public class Text2SQLEval extends BaseTest {
@Test @Test
public void test_term() throws Exception { public void test_term() throws Exception {
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId); QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
assert result.getQueryColumns().size() == 2; assert result.getQueryColumns().size() >= 1;
assert result.getQueryColumns().get(0).getName().contains("用户"); assert result.getQueryColumns().stream()
assert result.getQueryColumns().get(1).getName().contains("停留时长"); .filter(c -> c.getName().contains("停留时长")).collect(Collectors.toList()).size() == 1;
assert result.getQueryResults().size() == 2; assert result.getQueryResults().size() >= 1;
} }
public static Agent getLLMAgent(boolean enableMultiturn) { public Agent getLLMAgent(boolean enableMultiturn) {
Agent agent = new Agent(); Agent agent = new Agent();
agent.setName("Agent for Test"); agent.setName("Agent for Test");
AgentConfig agentConfig = new AgentConfig(); AgentConfig agentConfig = new AgentConfig();
@@ -110,7 +112,7 @@ public class Text2SQLEval extends BaseTest {
return agent; return agent;
} }
private static RuleParserTool getLLMQueryTool() { private RuleParserTool getLLMQueryTool() {
RuleParserTool ruleQueryTool = new RuleParserTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_LLM); ruleQueryTool.setType(AgentToolType.NL2SQL_LLM);
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L)); ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
@@ -126,7 +128,7 @@ public class Text2SQLEval extends BaseTest {
GLM GLM
} }
private static ChatModelConfig getLLMConfig(LLMType type) { protected ChatModelConfig getLLMConfig(LLMType type) {
String baseUrl; String baseUrl;
String apiKey; String apiKey;
String modelName; String modelName;