mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-27 20:42:33 +08:00
(improvement)(headless)Optimize Text2SQL prompt, explicitly ask LLM not hallucinate columns.
This commit is contained in:
@@ -16,6 +16,8 @@ import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
@Disabled
|
||||
public class Text2SQLEval extends BaseTest {
|
||||
@@ -91,13 +93,13 @@ public class Text2SQLEval extends BaseTest {
|
||||
@Test
|
||||
public void test_term() throws Exception {
|
||||
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().contains("用户");
|
||||
assert result.getQueryColumns().get(1).getName().contains("停留时长");
|
||||
assert result.getQueryResults().size() == 2;
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getQueryColumns().stream()
|
||||
.filter(c -> c.getName().contains("停留时长")).collect(Collectors.toList()).size() == 1;
|
||||
assert result.getQueryResults().size() >= 1;
|
||||
}
|
||||
|
||||
public static Agent getLLMAgent(boolean enableMultiturn) {
|
||||
public Agent getLLMAgent(boolean enableMultiturn) {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
@@ -110,7 +112,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
return agent;
|
||||
}
|
||||
|
||||
private static RuleParserTool getLLMQueryTool() {
|
||||
private RuleParserTool getLLMQueryTool() {
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
|
||||
@@ -126,7 +128,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
GLM
|
||||
}
|
||||
|
||||
private static ChatModelConfig getLLMConfig(LLMType type) {
|
||||
protected ChatModelConfig getLLMConfig(LLMType type) {
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
String modelName;
|
||||
|
||||
Reference in New Issue
Block a user