(improvement)(headless)Commit new impl of SqlGenStrategy

This commit is contained in:
jerryjzhang
2024-06-05 10:33:04 +08:00
parent 91e27bcadb
commit 008c1c35d8
12 changed files with 318 additions and 324 deletions

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
@@ -15,12 +15,7 @@ import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
import java.util.concurrent.ConcurrentHashMap;
@Service
@@ -29,46 +24,75 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
//1.recall exemplars
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
List<List<Map<String, String>>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
//2.generate sql generation prompt for each self-consistency inference
Map<Prompt, List<Map<String, String>>> prompt2Exemplar = new HashMap<>();
for (List<Map<String, String>> exemplars : exemplarsList) {
Prompt prompt = generatePrompt(llmReq, exemplars);
prompt2Exemplar.put(prompt, exemplars);
}
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
fewShotNumber, selfConsistencyNumber);
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);
List<String> llmResults = new CopyOnWriteArrayList<>();
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
//3.perform multiple self-consistency inferences parallelly
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
llmResults.add(result);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
prompt2Output.put(prompt, result);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
}
);
//3.format response.
List<String> sqlList = llmResults.stream()
.map(OutputFormat::getSql).collect(Collectors.toList());
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
//4.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
Lists.newArrayList(prompt2Output.values()));
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return result;
return llmResp;
}
private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
String instruction = ""
+ "#Role: You are a data analyst experienced in SQL languages.\n"
+ "#Task: You will be provided a natural language query asked by business users,"
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
+ "by executing the SQL query against underlying database.\n"
+ "#Rules:\n"
+ "1.Always use `数据日期` as the date field.\n"
+ "2.Always use `datediff` function to calculate date range.\n"
+ "3.Only output SQL statement.\n"
+ "#Exemplars:\n%s"
+ "#UserQuery: %s "
+ "#DatabaseMetadata: %s "
+ "#SQL: ";
StringBuilder exemplarsStr = new StringBuilder();
for (Map<String, String> example : fewshotExampleList) {
String metadata = example.get("dbSchema");
String question = example.get("questionAugmented");
String sql = example.get("sql");
String exemplarStr = String.format("#UserQuery: %s #DatabaseMetadata: %s #SQL: %s\n",
question, metadata, sql);
exemplarsStr.append(exemplarStr);
}
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dbSchema);
return PromptTemplate.from(promptStr).apply(new HashMap<>());
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
}
}

View File

@@ -1,156 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component
@Slf4j
public class PromptGenerator {
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
String instruction =
"# Find the schema_links for generating SQL queries for each question based on the database schema "
+ "and Foreign keys. Then use the the schema links to generate the "
+ "SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT\nSQL: sql";
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nA: Lets think step by step. In the question \"%s\", we are asked:";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
}
public String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
String instruction = "# Find the schema_links for generating SQL queries for each question "
+ "based on the database schema and Foreign keys.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nA: Lets think step by step. In the question \"%s\", we are asked:";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
}
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
+ "SQL: sql";
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
}
public List<String> generatePromptPool(LLMReq llmReq, List<List<Map<String, String>>> exampleListPool,
boolean isSqlPrompt) {
List<String> promptPool = new ArrayList<>();
for (List<Map<String, String>> exampleList : exampleListPool) {
String prompt;
if (isSqlPrompt) {
prompt = generatorLinkingAndSqlPrompt(llmReq, exampleList);
} else {
prompt = generateLinkingPrompt(llmReq, exampleList);
}
promptPool.add(prompt);
}
return promptPool;
}
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
int numSelfConsistency) {
List<List<Map<String, String>>> results = new ArrayList<>();
for (int i = 0; i < numSelfConsistency; i++) {
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
Collections.shuffle(shuffledList);
results.add(shuffledList.subList(0, numFewShots));
}
return results;
}
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
String modelName = llmReq.getSchema().getDataSetName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<LLMReq.ElementValue> linking = llmReq.getLinking();
String currentDate = llmReq.getCurrentDate();
String priorExts = llmReq.getPriorExts();
String dbSchema = "Table: " + modelName + ", Columns = " + fieldNameList + "\nForeign_keys: []";
List<String> priorLinkingList = new ArrayList<>();
for (ElementValue priorLinking : linking) {
String fieldName = priorLinking.getFieldName();
String fieldValue = priorLinking.getFieldValue();
priorLinkingList.add("" + fieldValue + "‘是一个‘" + fieldName + "");
}
String currentDataStr = "当前的日期是" + currentDate;
String linkingListStr = String.join("", priorLinkingList);
String termStr = getTermStr(llmReq);
String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(),
linkingListStr, currentDataStr, termStr, priorExts);
return Pair.of(dbSchema, questionAugmented);
}
private String getTermStr(LLMReq llmReq) {
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
StringBuilder termsDesc = new StringBuilder();
if (!CollectionUtils.isEmpty(terms)) {
termsDesc.append("相关业务术语:");
for (int idx = 0; idx < terms.size(); idx++) {
LLMReq.Term term = terms.get(idx);
String name = term.getName();
String description = term.getDescription();
List<String> alias = term.getAlias();
String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
termsDesc.append(String.format("%d.<%s>是业务术语%s%s", idx + 1, name, descPart, aliasPart));
}
if (termsDesc.length() > 0) {
termsDesc.setLength(termsDesc.length() - 1);
}
}
return termsDesc.toString();
}
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
List<List<Map<String, String>>> fewshotExampleListPool) {
List<String> sqlPromptPool = new ArrayList<>();
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
String schemaLinkStr = schemaLinkStrPool.get(i);
List<Map<String, String>> fewshotExampleList = fewshotExampleListPool.get(i);
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
sqlPromptPool.add(sqlPrompt);
}
return sqlPromptPool;
}
}

View File

@@ -0,0 +1,97 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
@Component
@Slf4j
public class PromptHelper {
@Autowired
private ParserConfig parserConfig;
@Autowired
private ExemplarManager exemplarManager;
public List<List<Map<String, String>>> getFewShotExemplars(LLMReq llmReq) {
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<Map<String, String>> exemplars = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> results = new ArrayList<>();
// use random collection of exemplars for each self-consistency inference
for (int i = 0; i < selfConsistencyNumber; i++) {
List<Map<String, String>> shuffledList = new ArrayList<>(exemplars);
Collections.shuffle(shuffledList);
results.add(shuffledList.subList(0, fewShotNumber));
}
return results;
}
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
String tableName = llmReq.getSchema().getDataSetName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
String currentDate = llmReq.getCurrentDate();
String priorExts = llmReq.getPriorExts();
String dbSchema = "Table: " + tableName + ", Columns = " + fieldNameList;
List<String> priorLinkingList = new ArrayList<>();
for (ElementValue value : linkedValues) {
String fieldName = value.getFieldName();
String fieldValue = value.getFieldValue();
priorLinkingList.add("" + fieldValue + "‘是一个‘" + fieldName + "");
}
String currentDataStr = "current date is " + currentDate;
String linkingListStr = String.join("", priorLinkingList);
String termStr = getTermStr(llmReq);
String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(),
linkingListStr, currentDataStr, termStr, priorExts);
return Pair.of(dbSchema, questionAugmented);
}
private String getTermStr(LLMReq llmReq) {
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
StringBuilder termsDesc = new StringBuilder();
if (!CollectionUtils.isEmpty(terms)) {
termsDesc.append("相关业务术语:");
for (int idx = 0; idx < terms.size(); idx++) {
LLMReq.Term term = terms.get(idx);
String name = term.getName();
String description = term.getDescription();
List<String> alias = term.getAlias();
String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
termsDesc.append(String.format("%d.<%s>是业务术语%s%s", idx + 1, name, descPart, aliasPart));
}
if (termsDesc.length() > 0) {
termsDesc.setLength(termsDesc.length() - 1);
}
}
return termsDesc.toString();
}
}

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.slf4j.Logger;
@@ -22,13 +21,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
protected ExemplarManager exemplarManager;
@Autowired
protected ParserConfig parserConfig;
@Autowired
protected PromptGenerator promptGenerator;
protected PromptHelper promptHelper;
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
return S2ChatModelProvider.provide(llmConfig);

View File

@@ -12,34 +12,31 @@ import dev.langchain4j.model.output.Response;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
@Service
@Deprecated
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
//1.recall exemplars
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<List<Map<String, String>>> exampleListPool = promptHelper.getFewShotExemplars(llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
fewShotNumber, selfConsistencyNumber);
//2.generate schema linking prompt for each self-consistency inference
List<String> linkingPromptPool = new ArrayList<>();
for (List<Map<String, String>> exampleList : exampleListPool) {
String prompt = generateLinkingPrompt(llmReq, exampleList);
linkingPromptPool.add(prompt);
}
//2.generator linking prompt,and parallel generate response.
List<String> linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false);
List<String> linkingResults = new CopyOnWriteArrayList<>();
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
linkingPromptPool.parallelStream().forEach(
@@ -53,8 +50,17 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
}
);
List<String> sortedList = OutputFormat.formatList(linkingResults);
//3.generator sql prompt,and parallel generate response.
List<String> sqlPromptPool = promptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
//3.generate sql generation prompt for each self-consistency inference
List<String> sqlPromptPool = new ArrayList<>();
for (int i = 0; i < sortedList.size(); i++) {
String schemaLinkStr = sortedList.get(i);
List<Map<String, String>> fewshotExampleList = exampleListPool.get(i);
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
sqlPromptPool.add(sqlPrompt);
}
//4.perform multiple self-consistency inferences parallelly
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
@@ -64,15 +70,49 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
sqlTaskPool.add(result);
});
//4.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
//5.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exampleListPool.get(0), sqlMapPair.getRight()));
return llmResp;
}
private String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
String instruction = "# Find the schema_links for generating SQL queries for each question "
+ "based on the database schema and Foreign keys.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nA: Lets think step by step. In the question \"%s\", we are asked:";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
}
private String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr,
List<Map<String, String>> fewshotExampleList) {
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
+ "SQL: sql";
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);

View File

@@ -68,7 +68,7 @@ public class LLMReq {
}
public enum SqlGenType {
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency"),
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
private String name;

View File

@@ -13,12 +13,12 @@ import java.util.List;
public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_STRATEGY_TYPE =
new Parameter("s2.parser.strategy", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
new Parameter("s2.parser.strategy", "ONE_PASS_SELF_CONSISTENCY",
"LLM解析生成S2SQL策略",
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql"
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql",
"list", "Parser相关配置", Lists.newArrayList(
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
"ONE_PASS_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
new Parameter("s2.parser.linking.value.enable", "true",
@@ -48,7 +48,7 @@ public class ParserConfig extends ParameterConfig {
"number", "Parser相关配置");
public static final Parameter PARSER_FEW_SHOT_NUMBER =
new Parameter("s2.parser.few-shot.number", "5",
new Parameter("s2.parser.few-shot.number", "3",
"few-shot样例个数", "样例越多效果可能越好但token消耗越大",
"number", "Parser相关配置");
@@ -70,11 +70,7 @@ public class ParserConfig extends ParameterConfig {
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
PARSER_STRATEGY_TYPE,
PARSER_LINKING_VALUE_ENABLE,
PARSER_TEXT_LENGTH_THRESHOLD,
PARSER_TEXT_LENGTH_THRESHOLD_SHORT,
PARSER_TEXT_LENGTH_THRESHOLD_LONG,
PARSER_FEW_SHOT_NUMBER,
PARSER_SELF_CONSISTENCY_NUMBER,
PARSER_SHOW_COUNT