mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(headless)Optimize Text2SQL prompt, explicitly ask LLM not hallucinate columns.
This commit is contained in:
@@ -31,11 +31,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||
+ "by executing the SQL query against underlying database.\n"
|
||||
+ "#Rules:"
|
||||
+ "1.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "2.ALWAYS calculate the absolute date range by yourself."
|
||||
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "4.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
||||
+ "5.ONLY respond with the converted SQL statement.\n"
|
||||
+ "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "3.ALWAYS calculate the absolute date range by yourself."
|
||||
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
||||
+ "6.ONLY respond with the converted SQL statement.\n"
|
||||
+ "#Exemplars:\n{{exemplar}}"
|
||||
+ "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:";
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@ import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
@Disabled
|
||||
public class Text2SQLEval extends BaseTest {
|
||||
@@ -91,13 +93,13 @@ public class Text2SQLEval extends BaseTest {
|
||||
@Test
|
||||
public void test_term() throws Exception {
|
||||
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||
assert result.getQueryColumns().get(1).getName().contains("停留时长");
|
||||
assert result.getQueryResults().size() == 2;
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getQueryColumns().stream()
|
||||
.filter(c -> c.getName().contains("停留时长")).collect(Collectors.toList()).size() == 1;
|
||||
assert result.getQueryResults().size() >= 1;
|
||||
}
|
||||
|
||||
public static Agent getLLMAgent(boolean enableMultiturn) {
|
||||
public Agent getLLMAgent(boolean enableMultiturn) {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
@@ -110,7 +112,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
return agent;
|
||||
}
|
||||
|
||||
private static RuleParserTool getLLMQueryTool() {
|
||||
private RuleParserTool getLLMQueryTool() {
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
|
||||
@@ -126,7 +128,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
GLM
|
||||
}
|
||||
|
||||
private static ChatModelConfig getLLMConfig(LLMType type) {
|
||||
protected ChatModelConfig getLLMConfig(LLMType type) {
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
String modelName;
|
||||
|
||||
Reference in New Issue
Block a user