mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless)Remove unused SqlGenStrategy
This commit is contained in:
@@ -13,7 +13,4 @@ public class Exemplar {
|
||||
|
||||
private String sql;
|
||||
|
||||
private String generatedSchemaLinkingCoT;
|
||||
|
||||
private String generatedSchemaLinkings;
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
@@ -14,6 +15,8 @@ import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -28,22 +31,29 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class ExemplarManager {
|
||||
@Order(0)
|
||||
public class ExemplarManager implements CommandLineRunner {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
|
||||
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
|
||||
};
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<Exemplar> getExemplars() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
|
||||
};
|
||||
|
||||
@Override
|
||||
public void run(String... args) {
|
||||
try {
|
||||
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
||||
loadDefaultExemplars();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to init examples", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
||||
@@ -79,4 +89,13 @@ public class ExemplarManager {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private void loadDefaultExemplars() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
List<Exemplar> examples = JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
addExemplars(examples, collectionName);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -23,6 +23,21 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a data analyst experienced in SQL languages.\n"
|
||||
+ "#Task: You will be provided a natural language query asked by business users,"
|
||||
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
|
||||
+ "by executing the SQL query against underlying database.\n"
|
||||
+ "#Rules:"
|
||||
+ "1.ALWAYS use `数据日期` as the date field."
|
||||
+ "2.ALWAYS use `datediff()` as the date function."
|
||||
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query."
|
||||
+ "4.ONLY respond with the converted SQL statement.\n"
|
||||
+ "#Exemplars:\n%s"
|
||||
+ "#UserQuery: %s "
|
||||
+ "#Schema: %s "
|
||||
+ "#SQL: ";
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.recall exemplars
|
||||
@@ -60,35 +75,19 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = ""
|
||||
+ "#Role: You are a data analyst experienced in SQL languages.\n"
|
||||
+ "#Task: You will be provided a natural language query asked by business users,"
|
||||
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
|
||||
+ "by executing the SQL query against underlying database.\n"
|
||||
+ "#Rules:"
|
||||
+ "1.ALWAYS use `数据日期` as the date field."
|
||||
+ "2.ALWAYS use `datediff()` as the date function."
|
||||
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query."
|
||||
+ "4.ONLY respond with the converted SQL statement.\n"
|
||||
+ "#Exemplars:\n%s"
|
||||
+ "#UserQuery: %s "
|
||||
+ "#DatabaseMetadata: %s "
|
||||
+ "#SQL: ";
|
||||
|
||||
StringBuilder exemplarsStr = new StringBuilder();
|
||||
for (Map<String, String> example : fewshotExampleList) {
|
||||
String metadata = example.get("dbSchema");
|
||||
String question = example.get("questionAugmented");
|
||||
String sql = example.get("sql");
|
||||
String exemplarStr = String.format("#UserQuery: %s #DatabaseMetadata: %s #SQL: %s\n",
|
||||
String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n",
|
||||
question, metadata, sql);
|
||||
exemplarsStr.append(exemplarStr);
|
||||
}
|
||||
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq);
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dataSemanticsStr);
|
||||
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
|
||||
String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr);
|
||||
|
||||
return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -47,15 +46,11 @@ public class PromptHelper {
|
||||
return results;
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String tableName = llmReq.getSchema().getDataSetName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
public String buildAugmentedQuestion(LLMReq llmReq) {
|
||||
List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
String priorExts = llmReq.getPriorExts();
|
||||
|
||||
String dbSchema = "Table: " + tableName + ", Columns = " + fieldNameList;
|
||||
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (LLMReq.ElementValue value : linkedValues) {
|
||||
String fieldName = value.getFieldName();
|
||||
@@ -65,10 +60,8 @@ public class PromptHelper {
|
||||
String currentDataStr = "当前的日期是" + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String termStr = buildTermStr(llmReq);
|
||||
String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
|
||||
return String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
|
||||
linkingListStr, currentDataStr, termStr, priorExts);
|
||||
|
||||
return Pair.of(dbSchema, questionAugmented);
|
||||
}
|
||||
|
||||
public String buildMetadataStr(LLMReq llmReq) {
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@Order(0)
|
||||
public class SqlEmbeddingListener implements CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private ExemplarManager exemplarManager;
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Override
|
||||
public void run(String... args) {
|
||||
initSqlExamples();
|
||||
}
|
||||
|
||||
public void initSqlExamples() {
|
||||
try {
|
||||
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
||||
List<Exemplar> exemplars = exemplarManager.getExemplars();
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
exemplarManager.addExemplars(exemplars, collectionName);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("initSqlExamples error", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Service
|
||||
@Deprecated
|
||||
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
//2.generate schema linking prompt for each self-consistency inference
|
||||
List<String> linkingPromptPool = new ArrayList<>();
|
||||
for (List<Map<String, String>> exampleList : exampleListPool) {
|
||||
String prompt = generateLinkingPrompt(llmReq, exampleList);
|
||||
linkingPromptPool.add(prompt);
|
||||
}
|
||||
|
||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
linkingPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = linkingResult.content().text();
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result);
|
||||
linkingResults.add(OutputFormat.getSchemaLink(result));
|
||||
}
|
||||
);
|
||||
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
||||
|
||||
//3.generate sql generation prompt for each self-consistency inference
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
for (int i = 0; i < sortedList.size(); i++) {
|
||||
String schemaLinkStr = sortedList.get(i);
|
||||
List<Map<String, String>> fewshotExampleList = exampleListPool.get(i);
|
||||
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
|
||||
sqlPromptPool.add(sqlPrompt);
|
||||
}
|
||||
|
||||
//4.perform multiple self-consistency inferences parallelly
|
||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
String result = sqlResult.content().text();
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
|
||||
sqlTaskPool.add(result);
|
||||
});
|
||||
|
||||
//5.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exampleListPool.get(0), sqlMapPair.getRight()));
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction = "# Find the schema_links for generating SQL queries for each question "
|
||||
+ "based on the database schema and Foreign keys.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
private String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr,
|
||||
List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
|
||||
+ "SQL: sql";
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
|
||||
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user