mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[improvement](chat) Support new four methods of generating SQL using SqlGeneration large models. (#519)
This commit is contained in:
@@ -63,10 +63,10 @@ public class OptimizationConfig {
|
||||
@Value("${text2sql.example.num:10}")
|
||||
private int text2sqlExampleNum;
|
||||
|
||||
@Value("${text2sql.fewShots.num:10}")
|
||||
@Value("${text2sql.fewShots.num:5}")
|
||||
private int text2sqlFewShotsNum;
|
||||
|
||||
@Value("${text2sql.self.consistency.num:5}")
|
||||
@Value("${text2sql.self.consistency.num:2}")
|
||||
private int text2sqlSelfConsistencyNum;
|
||||
|
||||
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
||||
|
||||
@@ -4,13 +4,14 @@ import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionPromptGenerator;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -36,12 +37,12 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
String sql = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
result.setSqlOutput(sql);
|
||||
result.setSqlWeight(sqlWeight);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
|
||||
List<String> linkingSqlPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
||||
List<String> llmResults = new CopyOnWriteArrayList<>();
|
||||
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
||||
.apply(new HashMap<>());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
llmResults.add(response.content().text());
|
||||
}
|
||||
);
|
||||
//3.format response.
|
||||
List<String> schemaLinkingResults = llmResults.stream()
|
||||
.map(llmResult -> OutputFormat.getSchemaLinks(llmResult)).collect(Collectors.toList());
|
||||
List<String> candidateSortedList = OutputFormat.formatList(schemaLinkingResults);
|
||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
|
||||
List<String> sqlList = llmResults.stream()
|
||||
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
|
||||
List<String> sqlListSortedList = OutputFormat.formatList(sqlList);
|
||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlListSortedList);
|
||||
log.info("linkingMap result:{},sqlMap:{}", linkingMap, sqlMap);
|
||||
return sqlMap.getRight();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
|
||||
//3.format response.
|
||||
String llmResult = response.content().text();
|
||||
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
|
||||
String sql = OutputFormat.getSql(response.content().text());
|
||||
Map<String, Double> sqlMap = new HashMap<>();
|
||||
sqlMap.put(sql, 1D);
|
||||
log.info("llmResult:{},schemaLinkStr:{},sql:{}", llmResult, schemaLinkStr, sql);
|
||||
return sqlMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OneStepsSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Override
|
||||
public String generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//TODO
|
||||
return "";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_STEP_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,15 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
/***
|
||||
* output format
|
||||
@@ -15,22 +20,96 @@ public class OutputFormat {
|
||||
|
||||
public static final String PATTERN = "\\{[^{}]+\\}";
|
||||
|
||||
public static String schemaLinkParse(String schemaLinkOutput) {
|
||||
public static String getSchemaLink(String schemaLink) {
|
||||
String reult = "";
|
||||
try {
|
||||
schemaLinkOutput = schemaLinkOutput.trim();
|
||||
reult = schemaLink.trim();
|
||||
String pattern = "Schema_links:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
|
||||
Matcher matcher = regexPattern.matcher(schemaLinkOutput);
|
||||
Matcher matcher = regexPattern.matcher(reult);
|
||||
if (matcher.find()) {
|
||||
schemaLinkOutput = matcher.group(1).trim();
|
||||
} else {
|
||||
schemaLinkOutput = "";
|
||||
return matcher.group(1).trim();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
schemaLinkOutput = "";
|
||||
}
|
||||
return schemaLinkOutput;
|
||||
return reult;
|
||||
}
|
||||
|
||||
public static String getSql(String sqlOutput) {
|
||||
String sql = "";
|
||||
try {
|
||||
sqlOutput = sqlOutput.trim();
|
||||
String pattern = "SQL:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern);
|
||||
Matcher matcher = regexPattern.matcher(sqlOutput);
|
||||
if (matcher.find()) {
|
||||
return matcher.group(1);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return sql;
|
||||
}
|
||||
|
||||
public static String getSchemaLinks(String text) {
|
||||
String schemaLinks = "";
|
||||
try {
|
||||
text = text.trim();
|
||||
String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])";
|
||||
Pattern regexPattern = Pattern.compile(pattern);
|
||||
Matcher matcher = regexPattern.matcher(text);
|
||||
|
||||
if (matcher.find()) {
|
||||
if (matcher.group(1) != null) {
|
||||
schemaLinks = matcher.group(1);
|
||||
} else if (matcher.group(2) != null) {
|
||||
schemaLinks = matcher.group(2);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
|
||||
return schemaLinks;
|
||||
}
|
||||
|
||||
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> inputList) {
|
||||
Map<String, Integer> inputCounts = new HashMap<>();
|
||||
for (String input : inputList) {
|
||||
inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1);
|
||||
}
|
||||
|
||||
String inputMax = null;
|
||||
int maxCount = 0;
|
||||
int inputSize = inputList.size();
|
||||
Map<String, Double> votePercentage = new HashMap<>();
|
||||
for (Map.Entry<String, Integer> entry : inputCounts.entrySet()) {
|
||||
String input = entry.getKey();
|
||||
int count = entry.getValue();
|
||||
if (count > maxCount) {
|
||||
inputMax = input;
|
||||
maxCount = count;
|
||||
}
|
||||
double percentage = (double) count / inputSize;
|
||||
votePercentage.put(input, percentage);
|
||||
}
|
||||
return Pair.of(inputMax, votePercentage);
|
||||
}
|
||||
|
||||
public static List<String> formatList(List<String> toFormatList) {
|
||||
List<String> results = new ArrayList<>();
|
||||
for (String toFormat : toFormatList) {
|
||||
List<String> items = new ArrayList<>();
|
||||
String[] split = toFormat.replace("[", "").replace("]", "").split(",");
|
||||
for (String item : split) {
|
||||
items.add(item.trim());
|
||||
}
|
||||
Collections.sort(items);
|
||||
String result = "[" + String.join(",", items) + "]";
|
||||
results.add(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
public static FunctionResp functionCallParse(String llmOutput) {
|
||||
@@ -47,7 +126,7 @@ public class OutputFormat {
|
||||
return resp;
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +1,19 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SqlExample {
|
||||
|
||||
@JsonProperty("currentDate")
|
||||
private String currentDate;
|
||||
|
||||
@JsonProperty("tableName")
|
||||
private String tableName;
|
||||
|
||||
@JsonProperty("fieldsList")
|
||||
private String fieldsList;
|
||||
|
||||
@JsonProperty("question")
|
||||
private String question;
|
||||
|
||||
@JsonProperty("priorSchemaLinks")
|
||||
private String priorSchemaLinks;
|
||||
private String questionAugmented;
|
||||
|
||||
@JsonProperty("analysis")
|
||||
private String analysis;
|
||||
private String dbSchema;
|
||||
|
||||
@JsonProperty("schemaLinks")
|
||||
private String schemaLinks;
|
||||
|
||||
@JsonProperty("sql")
|
||||
private String sql;
|
||||
|
||||
private String generatedSchemaLinkingCoT;
|
||||
|
||||
private String generatedSchemaLinkings;
|
||||
}
|
||||
@@ -26,7 +26,7 @@ import org.springframework.stereotype.Component;
|
||||
@Component
|
||||
public class SqlExampleLoader {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "example.json";
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
||||
|
||||
@@ -2,12 +2,19 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Sql Generation
|
||||
* Sql Generation interface, generating SQL using a large model.
|
||||
*/
|
||||
public interface SqlGeneration {
|
||||
|
||||
String generation(LLMReq llmReq, String modelClusterKey);
|
||||
/***
|
||||
* generate SQL through LLMReq.
|
||||
* @param llmReq
|
||||
* @param modelClusterKey
|
||||
* @return
|
||||
*/
|
||||
Map<String, Double> generation(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,65 +1,131 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class SqlPromptGenerator {
|
||||
|
||||
public String generateSchemaLinkingPrompt(String question, String modelName, List<String> fieldsList,
|
||||
List<ElementValue> priorSchemaLinks, List<Map<String, String>> exampleList) {
|
||||
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction =
|
||||
"# Find the schema_links for generating SQL queries for each question based on the database schema "
|
||||
+ "and Foreign keys. Then use the the schema links to generate the "
|
||||
+ "SQL queries for each of the questions.";
|
||||
|
||||
String exampleTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
|
||||
+ "问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}";
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT\nSQL: sql";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("tableName", "fieldsList", "priorSchemaLinks", "question", "analysis",
|
||||
"schemaLinks");
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
|
||||
String newCaseTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
|
||||
+ "问题:{question}\n分析: 让我们一步一步地思考。";
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
String newCasePrompt = newCaseTemplate.replace("{tableName}", modelName)
|
||||
.replace("{fieldsList}", fieldsList.toString())
|
||||
.replace("{priorSchemaLinks}", getPriorSchemaLinks(priorSchemaLinks))
|
||||
.replace("{question}", question);
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
String instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links";
|
||||
public String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction = "# Find the schema_links for generating SQL queries for each question "
|
||||
+ "based on the database schema and Foreign keys.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\n"
|
||||
+ "Schema_links: generatedSchemaLinkings\nSQL: {sql}";
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
|
||||
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
private String getPriorSchemaLinks(List<ElementValue> priorSchemaLinks) {
|
||||
return priorSchemaLinks.stream()
|
||||
.map(elementValue -> "'" + elementValue.getFieldName() + "'->" + elementValue.getFieldValue())
|
||||
.collect(Collectors.joining(",", "[", "]"));
|
||||
public List<String> generatePromptPool(LLMReq llmReq, List<List<Map<String, String>>> exampleListPool,
|
||||
boolean isSqlPrompt) {
|
||||
List<String> promptPool = new ArrayList<>();
|
||||
for (List<Map<String, String>> exampleList : exampleListPool) {
|
||||
String prompt;
|
||||
if (isSqlPrompt) {
|
||||
prompt = generatorLinkingAndSqlPrompt(llmReq, exampleList);
|
||||
} else {
|
||||
prompt = generateLinkingPrompt(llmReq, exampleList);
|
||||
}
|
||||
promptPool.add(prompt);
|
||||
}
|
||||
return promptPool;
|
||||
}
|
||||
|
||||
public String generateSqlPrompt(String question, String modelName, String schemaLinkStr, String dataDate,
|
||||
List<Map<String, String>> exampleList) {
|
||||
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
|
||||
int numSelfConsistency) {
|
||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||
for (int i = 0; i < numSelfConsistency; i++) {
|
||||
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, numFewShots));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("question", "currentDate", "tableName", "schemaLinks", "sql");
|
||||
String exampleTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
|
||||
+ "Schema_links:{schemaLinks}\nSQL:{sql}";
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
String priorExts = llmReq.getPriorExts();
|
||||
|
||||
String sqlExamplePrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
String dbSchema = "Table: " + modelName + ", Columns = " + fieldNameList + "\nForeign_keys: []";
|
||||
|
||||
String newCaseTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
|
||||
+ "Schema_links:{schemaLinks}\nSQL:";
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (ElementValue priorLinking : linking) {
|
||||
String fieldName = priorLinking.getFieldName();
|
||||
String fieldValue = priorLinking.getFieldValue();
|
||||
priorLinkingList.add("'" + fieldValue + "'是一个'" + fieldName + "'");
|
||||
}
|
||||
String currentDataStr = "当前的日期是" + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String questionAugmented = String.format("%s (补充信息:%s 。 %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
|
||||
currentDataStr, priorExts);
|
||||
return Pair.of(dbSchema, questionAugmented);
|
||||
}
|
||||
|
||||
String newCasePrompt = newCaseTemplate.replace("{question}", question)
|
||||
.replace("{currentDate}", dataDate)
|
||||
.replace("{tableName}", modelName)
|
||||
.replace("{schemaLinks}", schemaLinkStr);
|
||||
|
||||
String instruction = "# 根据schema_links为每个问题生成SQL查询语句";
|
||||
return instruction + InputFormat.SEPERATOR + sqlExamplePrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
|
||||
List<List<Map<String, String>>> fewshotExampleListPool) {
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
|
||||
String schemaLinkStr = schemaLinkStrPool.get(i);
|
||||
List<Map<String, String>> fewshotExampleList = fewshotExampleListPool.get(i);
|
||||
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
|
||||
sqlPromptPool.add(sqlPrompt);
|
||||
}
|
||||
return sqlPromptPool;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
//2.generator linking prompt,and parallel generate response.
|
||||
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
linkingPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = linkingResult.content().text();
|
||||
linkingResults.add(OutputFormat.getSchemaLink(result));
|
||||
}
|
||||
);
|
||||
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(sortedList);
|
||||
//3.generator sql prompt,and parallel generate response.
|
||||
List<String> sqlPromptPool = sqlPromptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
|
||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
String result = sqlResult.content().text();
|
||||
sqlTaskPool.add(result);
|
||||
});
|
||||
//4.format response.
|
||||
List<String> sqlTaskSortedList = OutputFormat.formatList(sqlTaskPool);
|
||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskSortedList);
|
||||
log.info("linkingMap result:{},sqlMap:{}", linkingMap, sqlMap);
|
||||
return sqlMap.getRight();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -21,7 +20,7 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoStepSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
@@ -36,36 +35,29 @@ public class TwoStepSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public String generation(LLMReq llmReq, String modelClusterKey) {
|
||||
String text2sqlCollectionName = optimizationConfig.getText2sqlCollectionName();
|
||||
int text2sqlFewShotsNum = optimizationConfig.getText2sqlFewShotsNum();
|
||||
String queryText = llmReq.getQueryText();
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(queryText, text2sqlCollectionName,
|
||||
text2sqlFewShotsNum);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||
|
||||
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
|
||||
linking, sqlExamples);
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
|
||||
|
||||
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
|
||||
|
||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
|
||||
llmReq.getCurrentDate(), sqlExamples);
|
||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
|
||||
|
||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||
return sqlResult.content().text();
|
||||
Map<String, Double> sqlMap = new HashMap<>();
|
||||
sqlMap.put(sqlResult.content().text(), 1D);
|
||||
return sqlMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEP_AUTO_COT, this);
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoStepCSSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Override
|
||||
public String generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//TODO
|
||||
return "";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEP_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -19,9 +19,7 @@ public class LLMReq {
|
||||
|
||||
private String priorExts;
|
||||
|
||||
// FIXME: Currently Java code is not use AUTO_COT, only two step, but it is used in Python
|
||||
// code, we use it here just for compatibility. The Java code will be updated in the future.
|
||||
private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_STEP_AUTO_COT;
|
||||
private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_PASS_AUTO_COT;
|
||||
|
||||
@Data
|
||||
public static class ElementValue {
|
||||
@@ -51,13 +49,13 @@ public class LLMReq {
|
||||
|
||||
public enum SqlGenerationMode {
|
||||
|
||||
ONE_STEP_AUTO_COT("1_pass_auto_cot"),
|
||||
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
|
||||
|
||||
ONE_STEP_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
|
||||
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
|
||||
|
||||
TWO_STEP_AUTO_COT("2_pass_auto_cot"),
|
||||
TWO_PASS_AUTO_COT("2_pass_auto_cot"),
|
||||
|
||||
TWO_STEP_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
|
||||
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
|
||||
|
||||
|
||||
private String name;
|
||||
|
||||
Reference in New Issue
Block a user