(improvement)(headless)Remove unused SqlGenStrategy

This commit is contained in:
jerryjzhang
2024-06-25 21:22:57 +08:00
parent d4cc53acae
commit ca68c445c2
6 changed files with 46 additions and 196 deletions

View File

@@ -13,7 +13,4 @@ public class Exemplar {
private String sql;
private String generatedSchemaLinkingCoT;
private String generatedSchemaLinkings;
}

View File

@@ -5,6 +5,7 @@ import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval;
@@ -14,6 +15,8 @@ import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
@@ -28,22 +31,29 @@ import java.util.stream.Collectors;
@Slf4j
@Component
public class ExemplarManager {
@Order(0)
public class ExemplarManager implements CommandLineRunner {
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
@Autowired
private EmbeddingService embeddingService;
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
};
@Autowired
private EmbeddingConfig embeddingConfig;
public List<Exemplar> getExemplars() throws IOException {
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
InputStream inputStream = resource.getInputStream();
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
};
/**
 * Startup hook: loads the default exemplars, but only when the configured LLM proxy is a
 * {@code JavaLLMProxy}. Any failure is logged and not propagated to the caller.
 *
 * @param args command-line arguments; not used here
 */
@Override
public void run(String... args) {
    try {
        boolean javaProxyInUse = ComponentFactory.getLLMProxy() instanceof JavaLLMProxy;
        if (!javaProxyInUse) {
            return;
        }
        loadDefaultExemplars();
    } catch (Exception e) {
        log.error("Failed to init examples", e);
    }
}
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
@@ -79,4 +89,13 @@ public class ExemplarManager {
}
return result;
}
/**
 * Reads the exemplar list from the classpath resource named by {@code EXAMPLE_JSON_FILE} and
 * adds it to the collection whose name is returned by
 * {@code embeddingConfig.getText2sqlCollectionName()}.
 *
 * @throws IOException if the resource cannot be opened, read, parsed or closed
 */
private void loadDefaultExemplars() throws IOException {
    ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
    List<Exemplar> examples;
    // Close the stream explicitly instead of relying on the JSON parser to do it,
    // so it is released even when parsing fails part-way through.
    try (InputStream inputStream = resource.getInputStream()) {
        examples = JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
    }
    String collectionName = embeddingConfig.getText2sqlCollectionName();
    addExemplars(examples, collectionName);
}
}

View File

@@ -23,6 +23,21 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
/**
 * Format-style prompt template for SQL generation. It contains a role, a task and a rule list,
 * followed by three {@code %s} placeholders which, in order, follow the labels
 * {@code #Exemplars}, {@code #UserQuery} and {@code #Schema}; the template ends with an open
 * {@code #SQL:} label. The text is sent to the model verbatim, so any edit changes the prompt.
 */
private static final String INSTRUCTION = ""
+ "#Role: You are a data analyst experienced in SQL languages.\n"
+ "#Task: You will be provided a natural language query asked by business users,"
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
+ "by executing the SQL query against underlying database.\n"
+ "#Rules:"
+ "1.ALWAYS use `数据日期` as the date field."
+ "2.ALWAYS use `datediff()` as the date function."
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query."
+ "4.ONLY respond with the converted SQL statement.\n"
+ "#Exemplars:\n%s"
+ "#UserQuery: %s "
+ "#Schema: %s "
+ "#SQL: ";
@Override
public LLMResp generate(LLMReq llmReq) {
//1.recall exemplars
@@ -60,35 +75,19 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
}
private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
String instruction = ""
+ "#Role: You are a data analyst experienced in SQL languages.\n"
+ "#Task: You will be provided a natural language query asked by business users,"
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
+ "by executing the SQL query against underlying database.\n"
+ "#Rules:"
+ "1.ALWAYS use `数据日期` as the date field."
+ "2.ALWAYS use `datediff()` as the date function."
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query."
+ "4.ONLY respond with the converted SQL statement.\n"
+ "#Exemplars:\n%s"
+ "#UserQuery: %s "
+ "#DatabaseMetadata: %s "
+ "#SQL: ";
StringBuilder exemplarsStr = new StringBuilder();
for (Map<String, String> example : fewshotExampleList) {
String metadata = example.get("dbSchema");
String question = example.get("questionAugmented");
String sql = example.get("sql");
String exemplarStr = String.format("#UserQuery: %s #DatabaseMetadata: %s #SQL: %s\n",
String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n",
question, metadata, sql);
exemplarsStr.append(exemplarStr);
}
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq);
String questionAugmented = questionPrompt.getRight();
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dataSemanticsStr);
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr);
return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
}

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -47,15 +46,11 @@ public class PromptHelper {
return results;
}
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
String tableName = llmReq.getSchema().getDataSetName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
public String buildAugmentedQuestion(LLMReq llmReq) {
List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
String currentDate = llmReq.getCurrentDate();
String priorExts = llmReq.getPriorExts();
String dbSchema = "Table: " + tableName + ", Columns = " + fieldNameList;
List<String> priorLinkingList = new ArrayList<>();
for (LLMReq.ElementValue value : linkedValues) {
String fieldName = value.getFieldName();
@@ -65,10 +60,8 @@ public class PromptHelper {
String currentDataStr = "当前的日期是" + currentDate;
String linkingListStr = String.join("", priorLinkingList);
String termStr = buildTermStr(llmReq);
String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
return String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
linkingListStr, currentDataStr, termStr, priorExts);
return Pair.of(dbSchema, questionAugmented);
}
public String buildMetadataStr(LLMReq llmReq) {

View File

@@ -1,39 +0,0 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Startup runner that seeds the text2sql collection with the exemplars supplied by
 * {@link ExemplarManager}. Seeding only happens when the configured LLM proxy is a
 * {@code JavaLLMProxy}.
 */
@Slf4j
@Component
@Order(0)
public class SqlEmbeddingListener implements CommandLineRunner {

    @Autowired
    private ExemplarManager exemplarManager;

    @Autowired
    private EmbeddingConfig embeddingConfig;

    /**
     * Delegates to {@link #initSqlExamples()}.
     *
     * @param args command-line arguments; not used here
     */
    @Override
    public void run(String... args) {
        initSqlExamples();
    }

    /**
     * Fetches the exemplars and adds them to the collection named by
     * {@code embeddingConfig.getText2sqlCollectionName()}. Any exception is logged and
     * swallowed, so this method never throws.
     */
    public void initSqlExamples() {
        try {
            if (!(ComponentFactory.getLLMProxy() instanceof JavaLLMProxy)) {
                return;
            }
            // Arguments are evaluated left to right: exemplars first, then the collection name.
            exemplarManager.addExemplars(
                    exemplarManager.getExemplars(),
                    embeddingConfig.getText2sqlCollectionName());
        } catch (Exception e) {
            log.error("initSqlExamples error", e);
        }
    }
}

View File

@@ -1,119 +0,0 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
/**
 * Deprecated SQL generation strategy that calls the chat model in two rounds: a first round
 * that asks for schema links and a second round that asks for SQL using those links. The
 * second-round answers are passed to {@code OutputFormat.selfConsistencyVote}.
 */
@Service
@Deprecated
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {

    /**
     * Runs the two rounds of model calls for the given request and builds the response.
     *
     * @param llmReq the request carrying the query text and LLM configuration
     * @return a response holding the query text and the SQL response map
     */
    @Override
    public LLMResp generate(LLMReq llmReq) {
        // 1. recall exemplars
        keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
        List<List<Map<String, String>>> exemplarPool = promptHelper.getFewShotExemplars(llmReq);

        // 2. build one schema-linking prompt per exemplar list and query the model for each
        List<String> linkingPrompts = new ArrayList<>();
        exemplarPool.forEach(exemplars -> linkingPrompts.add(generateLinkingPrompt(llmReq, exemplars)));

        List<String> schemaLinks = new CopyOnWriteArrayList<>();
        ChatLanguageModel model = getChatLanguageModel(llmReq.getLlmConfig());
        linkingPrompts.parallelStream().forEach(rawPrompt -> {
            Prompt prompt = PromptTemplate.from(JsonUtil.toString(rawPrompt)).apply(new HashMap<>());
            keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
            Response<AiMessage> reply = model.generate(prompt.toSystemMessage());
            String replyText = reply.content().text();
            keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", replyText);
            schemaLinks.add(OutputFormat.getSchemaLink(replyText));
        });
        List<String> orderedLinks = OutputFormat.formatList(schemaLinks);

        // 3. build one SQL-generation prompt per schema-link result
        // NOTE(review): results are collected from a parallel stream, so position i of
        // orderedLinks may not correspond to exemplarPool.get(i) - confirm this is intended.
        List<String> sqlPrompts = new ArrayList<>();
        int position = 0;
        for (String schemaLink : orderedLinks) {
            List<Map<String, String>> exemplars = exemplarPool.get(position);
            sqlPrompts.add(generateSqlPrompt(llmReq, schemaLink, exemplars));
            position++;
        }

        // 4. run the SQL-generation prompts in parallel
        List<String> sqlCandidates = new CopyOnWriteArrayList<>();
        sqlPrompts.parallelStream().forEach(rawPrompt -> {
            Prompt prompt = PromptTemplate.from(JsonUtil.toString(rawPrompt)).apply(new HashMap<>());
            keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", prompt.toSystemMessage());
            Response<AiMessage> reply = model.generate(prompt.toSystemMessage());
            String replyText = reply.content().text();
            keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", replyText);
            sqlCandidates.add(replyText);
        });

        // 5. format response
        Pair<String, Map<String, Double>> voteResult = OutputFormat.selfConsistencyVote(sqlCandidates);
        LLMResp response = new LLMResp();
        response.setQuery(llmReq.getQueryText());
        //TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
        response.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarPool.get(0), voteResult.getRight()));
        return response;
    }

    /**
     * Builds the first-round prompt: a fixed instruction, the formatted exemplars, and the new
     * case for the current request, joined with {@code InputFormat.SEPERATOR}.
     */
    private String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
        String header = "# Find the schema_links for generating SQL queries for each question "
                + "based on the database schema and Foreign keys.";

        String exemplarSection = InputFormat.format(
                "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT",
                Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT"),
                exampleList);

        Pair<String, String> schemaAndQuestion = promptHelper.transformQuestionPrompt(llmReq);
        String schema = schemaAndQuestion.getLeft();
        String question = schemaAndQuestion.getRight();
        String newCase = String.format(
                "%s\nQ: %s\nA: Lets think step by step. In the question \"%s\", we are asked:",
                schema, question, question);

        return new StringBuilder(header)
                .append(InputFormat.SEPERATOR)
                .append(exemplarSection)
                .append(InputFormat.SEPERATOR)
                .append(newCase)
                .toString();
    }

    /**
     * Builds the second-round prompt: a fixed instruction, the formatted exemplars, and the new
     * case including the given schema-link string, joined with {@code InputFormat.SEPERATOR}.
     */
    private String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr,
                                     List<Map<String, String>> fewshotExampleList) {
        String header = "# Use the the schema links to generate the SQL queries for each of the questions.";

        String exemplarSection = InputFormat.format(
                "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n" + "SQL: sql",
                Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"),
                fewshotExampleList);

        Pair<String, String> schemaAndQuestion = promptHelper.transformQuestionPrompt(llmReq);
        String schema = schemaAndQuestion.getLeft();
        String question = schemaAndQuestion.getRight();
        String newCase = String.format("%s\nQ: %s\nSchema_links: %s\nSQL: ", schema, question, schemaLinkStr);

        return new StringBuilder(header)
                .append(InputFormat.SEPERATOR)
                .append(exemplarSection)
                .append(InputFormat.SEPERATOR)
                .append(newCase)
                .toString();
    }

    /**
     * Registers this strategy with the factory under
     * {@code TWO_PASS_AUTO_COT_SELF_CONSISTENCY}.
     */
    @Override
    public void afterPropertiesSet() {
        SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
    }
}