From ecc651e12d29289cb5a4ffa259fc3e7a24bb63b0 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Fri, 9 Aug 2024 19:27:38 +0800 Subject: [PATCH] (improvement)(headless)Optimize Text2SQL prompt, explicitly ask LLM not to hallucinate columns. --- .../chat/parser/llm/OnePassSCSqlGenStrategy.java | 11 ++++++----- .../supersonic/evaluation/Text2SQLEval.java | 16 +++++++++------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index d013cf72c..8f92c358e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -31,11 +31,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "please convert it to a SQL query so that relevant data could be returned " + "by executing the SQL query against underlying database.\n" + "#Rules:" - + "1.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." - + "2.ALWAYS calculate the absolute date range by yourself." - + "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." - + "4.DO NOT miss the AGGREGATE operator of metrics, always add it if needed." - + "5.ONLY respond with the converted SQL statement.\n" + + "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate." + + "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + + "3.ALWAYS calculate the absolute date range by yourself." + + "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + + "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed." 
+ + "6.ONLY respond with the converted SQL statement.\n" + "#Exemplars:\n{{exemplar}}" + "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:"; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index ac91c032b..28138075e 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -16,6 +16,8 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import java.util.stream.Collectors; + @TestInstance(TestInstance.Lifecycle.PER_CLASS) @Disabled public class Text2SQLEval extends BaseTest { @@ -91,13 +93,13 @@ public class Text2SQLEval extends BaseTest { @Test public void test_term() throws Exception { QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId); - assert result.getQueryColumns().size() == 2; - assert result.getQueryColumns().get(0).getName().contains("用户"); - assert result.getQueryColumns().get(1).getName().contains("停留时长"); - assert result.getQueryResults().size() == 2; + assert result.getQueryColumns().size() >= 1; + assert result.getQueryColumns().stream() + .filter(c -> c.getName().contains("停留时长")).collect(Collectors.toList()).size() == 1; + assert result.getQueryResults().size() >= 1; } - public static Agent getLLMAgent(boolean enableMultiturn) { + public Agent getLLMAgent(boolean enableMultiturn) { Agent agent = new Agent(); agent.setName("Agent for Test"); AgentConfig agentConfig = new AgentConfig(); @@ -110,7 +112,7 @@ public class Text2SQLEval extends BaseTest { return agent; } - private static RuleParserTool getLLMQueryTool() { + private RuleParserTool getLLMQueryTool() { RuleParserTool ruleQueryTool = new RuleParserTool(); ruleQueryTool.setType(AgentToolType.NL2SQL_LLM); 
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L)); @@ -126,7 +128,7 @@ public class Text2SQLEval extends BaseTest { GLM } - private static ChatModelConfig getLLMConfig(LLMType type) { + protected ChatModelConfig getLLMConfig(LLMType type) { String baseUrl; String apiKey; String modelName;