[improvement](chat) Support new four methods of generating SQL using SqlGeneration large models. (#519)

This commit is contained in:
lexluo09
2023-12-16 14:44:56 +08:00
committed by GitHub
parent c86cd9f901
commit 9201550027
18 changed files with 700 additions and 460 deletions

View File

@@ -63,10 +63,10 @@ public class OptimizationConfig {
@Value("${text2sql.example.num:10}")
private int text2sqlExampleNum;
@Value("${text2sql.fewShots.num:10}")
@Value("${text2sql.fewShots.num:5}")
private int text2sqlFewShotsNum;
@Value("${text2sql.self.consistency.num:5}")
@Value("${text2sql.self.consistency.num:2}")
private int text2sqlSelfConsistencyNum;
@Value("${text2sql.collection.name:text2dsl_agent_collection}")

View File

@@ -4,13 +4,14 @@ import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionPromptGenerator;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@@ -36,12 +37,12 @@ public class JavaLLMProxy implements LLMProxy {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
String modelName = llmReq.getSchema().getModelName();
String sql = sqlGeneration.generation(llmReq, modelClusterKey);
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
result.setSqlOutput(sql);
result.setSqlWeight(sqlWeight);
return result;
}

View File

@@ -0,0 +1,76 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExampleLoader sqlExampleLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
@Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples and generate exampleListPool
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
List<String> linkingSqlPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, true);
List<String> llmResults = new CopyOnWriteArrayList<>();
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
llmResults.add(response.content().text());
}
);
//3.format response.
List<String> schemaLinkingResults = llmResults.stream()
.map(llmResult -> OutputFormat.getSchemaLinks(llmResult)).collect(Collectors.toList());
List<String> candidateSortedList = OutputFormat.formatList(schemaLinkingResults);
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
List<String> sqlList = llmResults.stream()
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
List<String> sqlListSortedList = OutputFormat.formatList(sqlList);
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlListSortedList);
log.info("linkingMap result:{},sqlMap:{}", linkingMap, sqlMap);
return sqlMap.getRight();
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
}
}

View File

@@ -0,0 +1,63 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExampleLoader sqlExampleLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
@Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response.
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
//3.format response.
String llmResult = response.content().text();
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
String sql = OutputFormat.getSql(response.content().text());
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(sql, 1D);
log.info("llmResult:{},schemaLinkStr:{},sql:{}", llmResult, schemaLinkStr, sql);
return sqlMap;
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this);
}
}

View File

@@ -1,24 +0,0 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class OneStepsSqlGeneration implements SqlGeneration, InitializingBean {
@Override
public String generation(LLMReq llmReq, String modelClusterKey) {
//TODO
return "";
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_STEP_AUTO_COT, this);
}
}

View File

@@ -2,10 +2,15 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
/***
* output format
@@ -15,22 +20,96 @@ public class OutputFormat {
public static final String PATTERN = "\\{[^{}]+\\}";
public static String schemaLinkParse(String schemaLinkOutput) {
public static String getSchemaLink(String schemaLink) {
String reult = "";
try {
schemaLinkOutput = schemaLinkOutput.trim();
reult = schemaLink.trim();
String pattern = "Schema_links:(.*)";
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
Matcher matcher = regexPattern.matcher(schemaLinkOutput);
Matcher matcher = regexPattern.matcher(reult);
if (matcher.find()) {
schemaLinkOutput = matcher.group(1).trim();
} else {
schemaLinkOutput = "";
return matcher.group(1).trim();
}
} catch (Exception e) {
log.error("", e);
schemaLinkOutput = "";
}
return schemaLinkOutput;
return reult;
}
public static String getSql(String sqlOutput) {
String sql = "";
try {
sqlOutput = sqlOutput.trim();
String pattern = "SQL:(.*)";
Pattern regexPattern = Pattern.compile(pattern);
Matcher matcher = regexPattern.matcher(sqlOutput);
if (matcher.find()) {
return matcher.group(1);
}
} catch (Exception e) {
log.error("", e);
}
return sql;
}
public static String getSchemaLinks(String text) {
String schemaLinks = "";
try {
text = text.trim();
String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])";
Pattern regexPattern = Pattern.compile(pattern);
Matcher matcher = regexPattern.matcher(text);
if (matcher.find()) {
if (matcher.group(1) != null) {
schemaLinks = matcher.group(1);
} else if (matcher.group(2) != null) {
schemaLinks = matcher.group(2);
}
}
} catch (Exception e) {
log.error("", e);
}
return schemaLinks;
}
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> inputList) {
Map<String, Integer> inputCounts = new HashMap<>();
for (String input : inputList) {
inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1);
}
String inputMax = null;
int maxCount = 0;
int inputSize = inputList.size();
Map<String, Double> votePercentage = new HashMap<>();
for (Map.Entry<String, Integer> entry : inputCounts.entrySet()) {
String input = entry.getKey();
int count = entry.getValue();
if (count > maxCount) {
inputMax = input;
maxCount = count;
}
double percentage = (double) count / inputSize;
votePercentage.put(input, percentage);
}
return Pair.of(inputMax, votePercentage);
}
public static List<String> formatList(List<String> toFormatList) {
List<String> results = new ArrayList<>();
for (String toFormat : toFormatList) {
List<String> items = new ArrayList<>();
String[] split = toFormat.replace("[", "").replace("]", "").split(",");
for (String item : split) {
items.add(item.trim());
}
Collections.sort(items);
String result = "[" + String.join(",", items) + "]";
results.add(result);
}
return results;
}
public static FunctionResp functionCallParse(String llmOutput) {
@@ -47,7 +126,7 @@ public class OutputFormat {
return resp;
} catch (Exception e) {
log.error("", e);
return null;
}
return null;
}
}

View File

@@ -1,32 +1,19 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
@Data
public class SqlExample {
@JsonProperty("currentDate")
private String currentDate;
@JsonProperty("tableName")
private String tableName;
@JsonProperty("fieldsList")
private String fieldsList;
@JsonProperty("question")
private String question;
@JsonProperty("priorSchemaLinks")
private String priorSchemaLinks;
private String questionAugmented;
@JsonProperty("analysis")
private String analysis;
private String dbSchema;
@JsonProperty("schemaLinks")
private String schemaLinks;
@JsonProperty("sql")
private String sql;
private String generatedSchemaLinkingCoT;
private String generatedSchemaLinkings;
}

View File

@@ -26,7 +26,7 @@ import org.springframework.stereotype.Component;
@Component
public class SqlExampleLoader {
private static final String EXAMPLE_JSON_FILE = "example.json";
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {

View File

@@ -2,12 +2,19 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import java.util.Map;
/**
* Sql Generation
* Sql Generation interface, generating SQL using a large model.
*/
public interface SqlGeneration {
String generation(LLMReq llmReq, String modelClusterKey);
/***
* generate SQL through LLMReq.
* @param llmReq
* @param modelClusterKey
* @return
*/
Map<String, Double> generation(LLMReq llmReq, String modelClusterKey);
}

View File

@@ -1,65 +1,131 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class SqlPromptGenerator {
public String generateSchemaLinkingPrompt(String question, String modelName, List<String> fieldsList,
List<ElementValue> priorSchemaLinks, List<Map<String, String>> exampleList) {
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
String instruction =
"# Find the schema_links for generating SQL queries for each question based on the database schema "
+ "and Foreign keys. Then use the the schema links to generate the "
+ "SQL queries for each of the questions.";
String exampleTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
+ "问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT\nSQL: sql";
List<String> exampleKeys = Arrays.asList("tableName", "fieldsList", "priorSchemaLinks", "question", "analysis",
"schemaLinks");
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
+ "问题:{question}\n分析: 让我们一步一步地思考。";
String newCaseTemplate = "%s\nQ: %s\nA: Lets think step by step. In the question \"%s\", we are asked:";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
String newCasePrompt = newCaseTemplate.replace("{tableName}", modelName)
.replace("{fieldsList}", fieldsList.toString())
.replace("{priorSchemaLinks}", getPriorSchemaLinks(priorSchemaLinks))
.replace("{question}", question);
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
}
String instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links";
public String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
String instruction = "# Find the schema_links for generating SQL queries for each question "
+ "based on the database schema and Foreign keys.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nA: Lets think step by step. In the question \"%s\", we are asked:";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
}
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\n"
+ "Schema_links: generatedSchemaLinkings\nSQL: {sql}";
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
}
private String getPriorSchemaLinks(List<ElementValue> priorSchemaLinks) {
return priorSchemaLinks.stream()
.map(elementValue -> "'" + elementValue.getFieldName() + "'->" + elementValue.getFieldValue())
.collect(Collectors.joining(",", "[", "]"));
public List<String> generatePromptPool(LLMReq llmReq, List<List<Map<String, String>>> exampleListPool,
boolean isSqlPrompt) {
List<String> promptPool = new ArrayList<>();
for (List<Map<String, String>> exampleList : exampleListPool) {
String prompt;
if (isSqlPrompt) {
prompt = generatorLinkingAndSqlPrompt(llmReq, exampleList);
} else {
prompt = generateLinkingPrompt(llmReq, exampleList);
}
promptPool.add(prompt);
}
return promptPool;
}
public String generateSqlPrompt(String question, String modelName, String schemaLinkStr, String dataDate,
List<Map<String, String>> exampleList) {
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
int numSelfConsistency) {
List<List<Map<String, String>>> results = new ArrayList<>();
for (int i = 0; i < numSelfConsistency; i++) {
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
Collections.shuffle(shuffledList);
results.add(shuffledList.subList(0, numFewShots));
}
return results;
}
List<String> exampleKeys = Arrays.asList("question", "currentDate", "tableName", "schemaLinks", "sql");
String exampleTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
+ "Schema_links:{schemaLinks}\nSQL:{sql}";
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
String modelName = llmReq.getSchema().getModelName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<ElementValue> linking = llmReq.getLinking();
String currentDate = llmReq.getCurrentDate();
String priorExts = llmReq.getPriorExts();
String sqlExamplePrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
String dbSchema = "Table: " + modelName + ", Columns = " + fieldNameList + "\nForeign_keys: []";
String newCaseTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
+ "Schema_links:{schemaLinks}\nSQL:";
List<String> priorLinkingList = new ArrayList<>();
for (ElementValue priorLinking : linking) {
String fieldName = priorLinking.getFieldName();
String fieldValue = priorLinking.getFieldValue();
priorLinkingList.add("'" + fieldValue + "'是一个'" + fieldName + "'");
}
String currentDataStr = "当前的日期是" + currentDate;
String linkingListStr = String.join("", priorLinkingList);
String questionAugmented = String.format("%s (补充信息:%s 。 %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
currentDataStr, priorExts);
return Pair.of(dbSchema, questionAugmented);
}
String newCasePrompt = newCaseTemplate.replace("{question}", question)
.replace("{currentDate}", dataDate)
.replace("{tableName}", modelName)
.replace("{schemaLinks}", schemaLinkStr);
String instruction = "# 根据schema_links为每个问题生成SQL查询语句";
return instruction + InputFormat.SEPERATOR + sqlExamplePrompt + InputFormat.SEPERATOR + newCasePrompt;
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
List<List<Map<String, String>>> fewshotExampleListPool) {
List<String> sqlPromptPool = new ArrayList<>();
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
String schemaLinkStr = schemaLinkStrPool.get(i);
List<Map<String, String>> fewshotExampleList = fewshotExampleListPool.get(i);
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
sqlPromptPool.add(sqlPrompt);
}
return sqlPromptPool;
}
}

View File

@@ -0,0 +1,81 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExampleLoader sqlExampleLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
@Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples and generate exampleListPool
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
//2.generator linking prompt,and parallel generate response.
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
List<String> linkingResults = new CopyOnWriteArrayList<>();
linkingPromptPool.parallelStream().forEach(
linkingPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
String result = linkingResult.content().text();
linkingResults.add(OutputFormat.getSchemaLink(result));
}
);
List<String> sortedList = OutputFormat.formatList(linkingResults);
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(sortedList);
//3.generator sql prompt,and parallel generate response.
List<String> sqlPromptPool = sqlPromptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String result = sqlResult.content().text();
sqlTaskPool.add(result);
});
//4.format response.
List<String> sqlTaskSortedList = OutputFormat.formatList(sqlTaskPool);
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskSortedList);
log.info("linkingMap result:{},sqlMap:{}", linkingMap, sqlMap);
return sqlMap.getRight();
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
}
}

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
@@ -21,7 +20,7 @@ import org.springframework.stereotype.Service;
@Service
@Slf4j
public class TwoStepSqlGeneration implements SqlGeneration, InitializingBean {
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
@Autowired
private ChatLanguageModel chatLanguageModel;
@@ -36,36 +35,29 @@ public class TwoStepSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public String generation(LLMReq llmReq, String modelClusterKey) {
String text2sqlCollectionName = optimizationConfig.getText2sqlCollectionName();
int text2sqlFewShotsNum = optimizationConfig.getText2sqlFewShotsNum();
String queryText = llmReq.getQueryText();
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(queryText, text2sqlCollectionName,
text2sqlFewShotsNum);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
String modelName = llmReq.getSchema().getModelName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<ElementValue> linking = llmReq.getLinking();
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
linking, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
llmReq.getCurrentDate(), sqlExamples);
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
return sqlResult.content().text();
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(sqlResult.content().text(), 1D);
return sqlMap;
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEP_AUTO_COT, this);
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT, this);
}
}

View File

@@ -1,24 +0,0 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class TwoStepCSSqlGeneration implements SqlGeneration, InitializingBean {
@Override
public String generation(LLMReq llmReq, String modelClusterKey) {
//TODO
return "";
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEP_AUTO_COT_SELF_CONSISTENCY, this);
}
}

View File

@@ -19,9 +19,7 @@ public class LLMReq {
private String priorExts;
// FIXME: Currently Java code is not use AUTO_COT, only two step, but it is used in Python
// code, we use it here just for compatibility. The Java code will be updated in the future.
private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_STEP_AUTO_COT;
private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_PASS_AUTO_COT;
@Data
public static class ElementValue {
@@ -51,13 +49,13 @@ public class LLMReq {
public enum SqlGenerationMode {
ONE_STEP_AUTO_COT("1_pass_auto_cot"),
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
ONE_STEP_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
TWO_STEP_AUTO_COT("2_pass_auto_cot"),
TWO_PASS_AUTO_COT("2_pass_auto_cot"),
TWO_STEP_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
private String name;