mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless)Optimize Text2SQL prompt, explicitly ask LLM not hallucinate columns.
This commit is contained in:
@@ -31,11 +31,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||||
+ "by executing the SQL query against underlying database.\n"
|
+ "by executing the SQL query against underlying database.\n"
|
||||||
+ "#Rules:"
|
+ "#Rules:"
|
||||||
+ "1.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
+ "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
|
||||||
+ "2.ALWAYS calculate the absolute date range by yourself."
|
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||||
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
+ "3.ALWAYS calculate the absolute date range by yourself."
|
||||||
+ "4.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||||
+ "5.ONLY respond with the converted SQL statement.\n"
|
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
||||||
|
+ "6.ONLY respond with the converted SQL statement.\n"
|
||||||
+ "#Exemplars:\n{{exemplar}}"
|
+ "#Exemplars:\n{{exemplar}}"
|
||||||
+ "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:";
|
+ "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:";
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import org.junit.jupiter.api.Disabled;
|
|||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.TestInstance;
|
import org.junit.jupiter.api.TestInstance;
|
||||||
|
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||||
@Disabled
|
@Disabled
|
||||||
public class Text2SQLEval extends BaseTest {
|
public class Text2SQLEval extends BaseTest {
|
||||||
@@ -91,13 +93,13 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
@Test
|
@Test
|
||||||
public void test_term() throws Exception {
|
public void test_term() throws Exception {
|
||||||
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
||||||
assert result.getQueryColumns().size() == 2;
|
assert result.getQueryColumns().size() >= 1;
|
||||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
assert result.getQueryColumns().stream()
|
||||||
assert result.getQueryColumns().get(1).getName().contains("停留时长");
|
.filter(c -> c.getName().contains("停留时长")).collect(Collectors.toList()).size() == 1;
|
||||||
assert result.getQueryResults().size() == 2;
|
assert result.getQueryResults().size() >= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Agent getLLMAgent(boolean enableMultiturn) {
|
public Agent getLLMAgent(boolean enableMultiturn) {
|
||||||
Agent agent = new Agent();
|
Agent agent = new Agent();
|
||||||
agent.setName("Agent for Test");
|
agent.setName("Agent for Test");
|
||||||
AgentConfig agentConfig = new AgentConfig();
|
AgentConfig agentConfig = new AgentConfig();
|
||||||
@@ -110,7 +112,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
return agent;
|
return agent;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static RuleParserTool getLLMQueryTool() {
|
private RuleParserTool getLLMQueryTool() {
|
||||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||||
ruleQueryTool.setType(AgentToolType.NL2SQL_LLM);
|
ruleQueryTool.setType(AgentToolType.NL2SQL_LLM);
|
||||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
|
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
|
||||||
@@ -126,7 +128,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
GLM
|
GLM
|
||||||
}
|
}
|
||||||
|
|
||||||
private static ChatModelConfig getLLMConfig(LLMType type) {
|
protected ChatModelConfig getLLMConfig(LLMType type) {
|
||||||
String baseUrl;
|
String baseUrl;
|
||||||
String apiKey;
|
String apiKey;
|
||||||
String modelName;
|
String modelName;
|
||||||
|
|||||||
Reference in New Issue
Block a user