mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(headless)Commit new impl of SqlGenStrategy
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -15,12 +15,7 @@ import org.springframework.stereotype.Service;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
|
||||
@Service
|
||||
@@ -29,46 +24,75 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||
List<List<Map<String, String>>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
//2.generate sql generation prompt for each self-consistency inference
|
||||
Map<Prompt, List<Map<String, String>>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<Map<String, String>> exemplars : exemplarsList) {
|
||||
Prompt prompt = generatePrompt(llmReq, exemplars);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
}
|
||||
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
exemplarRecallNumber);
|
||||
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||
fewShotNumber, selfConsistencyNumber);
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
|
||||
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
||||
List<String> llmResults = new CopyOnWriteArrayList<>();
|
||||
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
||||
.apply(new HashMap<>());
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
//3.perform multiple self-consistency inferences parallelly
|
||||
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
|
||||
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
llmResults.add(result);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
|
||||
prompt2Output.put(prompt, result);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
||||
}
|
||||
);
|
||||
//3.format response.
|
||||
List<String> sqlList = llmResults.stream()
|
||||
.map(OutputFormat::getSql).collect(Collectors.toList());
|
||||
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
|
||||
Lists.newArrayList(prompt2Output.values()));
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
|
||||
return result;
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = ""
|
||||
+ "#Role: You are a data analyst experienced in SQL languages.\n"
|
||||
+ "#Task: You will be provided a natural language query asked by business users,"
|
||||
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
|
||||
+ "by executing the SQL query against underlying database.\n"
|
||||
+ "#Rules:\n"
|
||||
+ "1.Always use `数据日期` as the date field.\n"
|
||||
+ "2.Always use `datediff` function to calculate date range.\n"
|
||||
+ "3.Only output SQL statement.\n"
|
||||
+ "#Exemplars:\n%s"
|
||||
+ "#UserQuery: %s "
|
||||
+ "#DatabaseMetadata: %s "
|
||||
+ "#SQL: ";
|
||||
|
||||
StringBuilder exemplarsStr = new StringBuilder();
|
||||
for (Map<String, String> example : fewshotExampleList) {
|
||||
String metadata = example.get("dbSchema");
|
||||
String question = example.get("questionAugmented");
|
||||
String sql = example.get("sql");
|
||||
String exemplarStr = String.format("#UserQuery: %s #DatabaseMetadata: %s #SQL: %s\n",
|
||||
question, metadata, sql);
|
||||
exemplarsStr.append(exemplarStr);
|
||||
}
|
||||
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dbSchema);
|
||||
|
||||
return PromptTemplate.from(promptStr).apply(new HashMap<>());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class PromptGenerator {
|
||||
|
||||
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction =
|
||||
"# Find the schema_links for generating SQL queries for each question based on the database schema "
|
||||
+ "and Foreign keys. Then use the the schema links to generate the "
|
||||
+ "SQL queries for each of the questions.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT\nSQL: sql";
|
||||
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
public String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction = "# Find the schema_links for generating SQL queries for each question "
|
||||
+ "based on the database schema and Foreign keys.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
|
||||
+ "SQL: sql";
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
|
||||
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
public List<String> generatePromptPool(LLMReq llmReq, List<List<Map<String, String>>> exampleListPool,
|
||||
boolean isSqlPrompt) {
|
||||
List<String> promptPool = new ArrayList<>();
|
||||
for (List<Map<String, String>> exampleList : exampleListPool) {
|
||||
String prompt;
|
||||
if (isSqlPrompt) {
|
||||
prompt = generatorLinkingAndSqlPrompt(llmReq, exampleList);
|
||||
} else {
|
||||
prompt = generateLinkingPrompt(llmReq, exampleList);
|
||||
}
|
||||
promptPool.add(prompt);
|
||||
}
|
||||
return promptPool;
|
||||
}
|
||||
|
||||
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
|
||||
int numSelfConsistency) {
|
||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||
for (int i = 0; i < numSelfConsistency; i++) {
|
||||
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, numFewShots));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<LLMReq.ElementValue> linking = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
String priorExts = llmReq.getPriorExts();
|
||||
|
||||
String dbSchema = "Table: " + modelName + ", Columns = " + fieldNameList + "\nForeign_keys: []";
|
||||
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (ElementValue priorLinking : linking) {
|
||||
String fieldName = priorLinking.getFieldName();
|
||||
String fieldValue = priorLinking.getFieldValue();
|
||||
priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘");
|
||||
}
|
||||
String currentDataStr = "当前的日期是" + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String termStr = getTermStr(llmReq);
|
||||
String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(),
|
||||
linkingListStr, currentDataStr, termStr, priorExts);
|
||||
return Pair.of(dbSchema, questionAugmented);
|
||||
}
|
||||
|
||||
private String getTermStr(LLMReq llmReq) {
|
||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
||||
StringBuilder termsDesc = new StringBuilder();
|
||||
if (!CollectionUtils.isEmpty(terms)) {
|
||||
termsDesc.append("相关业务术语:");
|
||||
for (int idx = 0; idx < terms.size(); idx++) {
|
||||
LLMReq.Term term = terms.get(idx);
|
||||
String name = term.getName();
|
||||
String description = term.getDescription();
|
||||
List<String> alias = term.getAlias();
|
||||
String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
|
||||
String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
|
||||
termsDesc.append(String.format("%d.<%s>是业务术语%s%s;", idx + 1, name, descPart, aliasPart));
|
||||
}
|
||||
if (termsDesc.length() > 0) {
|
||||
termsDesc.setLength(termsDesc.length() - 1);
|
||||
}
|
||||
}
|
||||
return termsDesc.toString();
|
||||
}
|
||||
|
||||
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
|
||||
List<List<Map<String, String>>> fewshotExampleListPool) {
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
|
||||
String schemaLinkStr = schemaLinkStrPool.get(i);
|
||||
List<Map<String, String>> fewshotExampleList = fewshotExampleListPool.get(i);
|
||||
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
|
||||
sqlPromptPool.add(sqlPrompt);
|
||||
}
|
||||
return sqlPromptPool;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.headless.core.config.ParserConfig;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class PromptHelper {
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
@Autowired
|
||||
private ExemplarManager exemplarManager;
|
||||
|
||||
public List<List<Map<String, String>>> getFewShotExemplars(LLMReq llmReq) {
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<Map<String, String>> exemplars = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
exemplarRecallNumber);
|
||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||
|
||||
// use random collection of exemplars for each self-consistency inference
|
||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||
List<Map<String, String>> shuffledList = new ArrayList<>(exemplars);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, fewShotNumber));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String tableName = llmReq.getSchema().getDataSetName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
String priorExts = llmReq.getPriorExts();
|
||||
|
||||
String dbSchema = "Table: " + tableName + ", Columns = " + fieldNameList;
|
||||
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (ElementValue value : linkedValues) {
|
||||
String fieldName = value.getFieldName();
|
||||
String fieldValue = value.getFieldValue();
|
||||
priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘");
|
||||
}
|
||||
String currentDataStr = "current date is " + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String termStr = getTermStr(llmReq);
|
||||
String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(),
|
||||
linkingListStr, currentDataStr, termStr, priorExts);
|
||||
|
||||
return Pair.of(dbSchema, questionAugmented);
|
||||
}
|
||||
|
||||
private String getTermStr(LLMReq llmReq) {
|
||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
||||
StringBuilder termsDesc = new StringBuilder();
|
||||
if (!CollectionUtils.isEmpty(terms)) {
|
||||
termsDesc.append("相关业务术语:");
|
||||
for (int idx = 0; idx < terms.size(); idx++) {
|
||||
LLMReq.Term term = terms.get(idx);
|
||||
String name = term.getName();
|
||||
String description = term.getDescription();
|
||||
List<String> alias = term.getAlias();
|
||||
String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
|
||||
String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
|
||||
termsDesc.append(String.format("%d.<%s>是业务术语%s%s;", idx + 1, name, descPart, aliasPart));
|
||||
}
|
||||
if (termsDesc.length() > 0) {
|
||||
termsDesc.setLength(termsDesc.length() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
return termsDesc.toString();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.config.ParserConfig;
|
||||
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import org.slf4j.Logger;
|
||||
@@ -22,13 +21,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
|
||||
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Autowired
|
||||
protected ExemplarManager exemplarManager;
|
||||
|
||||
@Autowired
|
||||
protected ParserConfig parserConfig;
|
||||
|
||||
@Autowired
|
||||
protected PromptGenerator promptGenerator;
|
||||
protected PromptHelper promptHelper;
|
||||
|
||||
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
|
||||
return S2ChatModelProvider.provide(llmConfig);
|
||||
|
||||
@@ -12,34 +12,31 @@ import dev.langchain4j.model.output.Response;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
||||
|
||||
@Service
|
||||
@Deprecated
|
||||
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
List<List<Map<String, String>>> exampleListPool = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
exemplarRecallNumber);
|
||||
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||
fewShotNumber, selfConsistencyNumber);
|
||||
//2.generate schema linking prompt for each self-consistency inference
|
||||
List<String> linkingPromptPool = new ArrayList<>();
|
||||
for (List<Map<String, String>> exampleList : exampleListPool) {
|
||||
String prompt = generateLinkingPrompt(llmReq, exampleList);
|
||||
linkingPromptPool.add(prompt);
|
||||
}
|
||||
|
||||
//2.generator linking prompt,and parallel generate response.
|
||||
List<String> linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
@@ -53,8 +50,17 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
}
|
||||
);
|
||||
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
||||
//3.generator sql prompt,and parallel generate response.
|
||||
List<String> sqlPromptPool = promptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
|
||||
|
||||
//3.generate sql generation prompt for each self-consistency inference
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
for (int i = 0; i < sortedList.size(); i++) {
|
||||
String schemaLinkStr = sortedList.get(i);
|
||||
List<Map<String, String>> fewshotExampleList = exampleListPool.get(i);
|
||||
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
|
||||
sqlPromptPool.add(sqlPrompt);
|
||||
}
|
||||
|
||||
//4.perform multiple self-consistency inferences parallelly
|
||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||
@@ -64,15 +70,49 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
|
||||
sqlTaskPool.add(result);
|
||||
});
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
|
||||
//5.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
|
||||
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exampleListPool.get(0), sqlMapPair.getRight()));
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction = "# Find the schema_links for generating SQL queries for each question "
|
||||
+ "based on the database schema and Foreign keys.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
private String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr,
|
||||
List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
|
||||
+ "SQL: sql";
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
|
||||
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
|
||||
@@ -68,7 +68,7 @@ public class LLMReq {
|
||||
}
|
||||
|
||||
public enum SqlGenType {
|
||||
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
|
||||
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency"),
|
||||
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
|
||||
|
||||
private String name;
|
||||
|
||||
@@ -13,12 +13,12 @@ import java.util.List;
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_STRATEGY_TYPE =
|
||||
new Parameter("s2.parser.strategy", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
|
||||
new Parameter("s2.parser.strategy", "ONE_PASS_SELF_CONSISTENCY",
|
||||
"LLM解析生成S2SQL策略",
|
||||
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
|
||||
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql"
|
||||
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql",
|
||||
"list", "Parser相关配置", Lists.newArrayList(
|
||||
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
|
||||
"ONE_PASS_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
|
||||
|
||||
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
|
||||
new Parameter("s2.parser.linking.value.enable", "true",
|
||||
@@ -48,7 +48,7 @@ public class ParserConfig extends ParameterConfig {
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_FEW_SHOT_NUMBER =
|
||||
new Parameter("s2.parser.few-shot.number", "5",
|
||||
new Parameter("s2.parser.few-shot.number", "3",
|
||||
"few-shot样例个数", "样例越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
@@ -70,11 +70,7 @@ public class ParserConfig extends ParameterConfig {
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
PARSER_STRATEGY_TYPE,
|
||||
PARSER_LINKING_VALUE_ENABLE,
|
||||
PARSER_TEXT_LENGTH_THRESHOLD,
|
||||
PARSER_TEXT_LENGTH_THRESHOLD_SHORT,
|
||||
PARSER_TEXT_LENGTH_THRESHOLD_LONG,
|
||||
PARSER_FEW_SHOT_NUMBER,
|
||||
PARSER_SELF_CONSISTENCY_NUMBER,
|
||||
PARSER_SHOW_COUNT
|
||||
|
||||
Reference in New Issue
Block a user