From ca68c445c234f613a0024d4052a95268ca22f57c Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Tue, 25 Jun 2024 21:22:57 +0800 Subject: [PATCH] (improvement)(headless)Remove unused SqlGenStrategy --- .../headless/chat/parser/llm/Exemplar.java | 3 - .../chat/parser/llm/ExemplarManager.java | 33 +++-- .../parser/llm/OnePassSCSqlGenStrategy.java | 37 +++--- .../chat/parser/llm/PromptHelper.java | 11 +- .../chat/parser/llm/SqlEmbeddingListener.java | 39 ------ .../parser/llm/TwoPassSCSqlGenStrategy.java | 119 ------------------ 6 files changed, 46 insertions(+), 196 deletions(-) delete mode 100644 headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlEmbeddingListener.java delete mode 100644 headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/TwoPassSCSqlGenStrategy.java diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java index 7d924c317..7c64f9a1a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java @@ -13,7 +13,4 @@ public class Exemplar { private String sql; - private String generatedSchemaLinkingCoT; - - private String generatedSchemaLinkings; } \ No newline at end of file diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java index ea54fe6bd..cb90965f6 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.Retrieval; @@ -14,6 +15,8 @@ import dev.langchain4j.store.embedding.TextSegmentConvert; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.CommandLineRunner; +import org.springframework.core.annotation.Order; import org.springframework.core.io.ClassPathResource; import org.springframework.stereotype.Component; @@ -28,22 +31,29 @@ import java.util.stream.Collectors; @Slf4j @Component -public class ExemplarManager { +@Order(0) +public class ExemplarManager implements CommandLineRunner { private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json"; @Autowired private EmbeddingService embeddingService; - private TypeReference> valueTypeRef = new TypeReference>() { - }; @Autowired private EmbeddingConfig embeddingConfig; - public List getExemplars() throws IOException { - ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); - InputStream inputStream = resource.getInputStream(); - return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef); + private TypeReference> valueTypeRef = new TypeReference>() { + }; + + @Override + public void run(String... args) { + try { + if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) { + loadDefaultExemplars(); + } + } catch (Exception e) { + log.error("Failed to init examples", e); + } } public void addExemplars(List exemplars, String collectionName) { @@ -79,4 +89,13 @@ public class ExemplarManager { } return result; } + + private void loadDefaultExemplars() throws IOException { + ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); + InputStream inputStream = resource.getInputStream(); + List examples = JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef); + String collectionName = embeddingConfig.getText2sqlCollectionName(); + addExemplars(examples, collectionName); + } + } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 0f2f49130..86b3d7537 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -23,6 +23,21 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + private static final String INSTRUCTION = "" + + "#Role: You are a data analyst experienced in SQL languages.\n" + + "#Task: You will be provided a natural language query asked by business users," + + "please convert it to a SQL query so that relevant answer could be returned to the user " + + "by executing the SQL query against underlying database.\n" + + "#Rules:" + + "1.ALWAYS use `数据日期` as the date field." + + "2.ALWAYS use `datediff()` as the date function." + + "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query." + + "4.ONLY respond with the converted SQL statement.\n" + + "#Exemplars:\n%s" + + "#UserQuery: %s " + + "#Schema: %s " + + "#SQL: "; + @Override public LLMResp generate(LLMReq llmReq) { //1.recall exemplars @@ -60,35 +75,19 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { } private Prompt generatePrompt(LLMReq llmReq, List> fewshotExampleList) { - String instruction = "" - + "#Role: You are a data analyst experienced in SQL languages.\n" - + "#Task: You will be provided a natural language query asked by business users," - + "please convert it to a SQL query so that relevant answer could be returned to the user " - + "by executing the SQL query against underlying database.\n" - + "#Rules:" - + "1.ALWAYS use `数据日期` as the date field." - + "2.ALWAYS use `datediff()` as the date function." - + "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query." - + "4.ONLY respond with the converted SQL statement.\n" - + "#Exemplars:\n%s" - + "#UserQuery: %s " - + "#DatabaseMetadata: %s " - + "#SQL: "; - StringBuilder exemplarsStr = new StringBuilder(); for (Map example : fewshotExampleList) { String metadata = example.get("dbSchema"); String question = example.get("questionAugmented"); String sql = example.get("sql"); - String exemplarStr = String.format("#UserQuery: %s #DatabaseMetadata: %s #SQL: %s\n", + String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n", question, metadata, sql); exemplarsStr.append(exemplarStr); } - Pair questionPrompt = promptHelper.transformQuestionPrompt(llmReq); String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq); - String questionAugmented = questionPrompt.getRight(); - String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dataSemanticsStr); + String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq); + String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr); return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index 23f5953b8..da2a83dc7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -4,7 +4,6 @@ import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -47,15 +46,11 @@ public class PromptHelper { return results; } - public Pair transformQuestionPrompt(LLMReq llmReq) { - String tableName = llmReq.getSchema().getDataSetName(); - List fieldNameList = llmReq.getSchema().getFieldNameList(); + public String buildAugmentedQuestion(LLMReq llmReq) { List linkedValues = llmReq.getLinking(); String currentDate = llmReq.getCurrentDate(); String priorExts = llmReq.getPriorExts(); - String dbSchema = "Table: " + tableName + ", Columns = " + fieldNameList; - List priorLinkingList = new ArrayList<>(); for (LLMReq.ElementValue value : linkedValues) { String fieldName = value.getFieldName(); @@ -65,10 +60,8 @@ public class PromptHelper { String currentDataStr = "当前的日期是" + currentDate; String linkingListStr = String.join(",", priorLinkingList); String termStr = buildTermStr(llmReq); - String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(), + return String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(), linkingListStr, currentDataStr, termStr, priorExts); - - return Pair.of(dbSchema, questionAugmented); } public String buildMetadataStr(LLMReq llmReq) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlEmbeddingListener.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlEmbeddingListener.java deleted file mode 100644 index b5c3468eb..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlEmbeddingListener.java +++ /dev/null @@ -1,39 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.llm; - -import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.headless.chat.utils.ComponentFactory; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.CommandLineRunner; -import org.springframework.core.annotation.Order; -import org.springframework.stereotype.Component; - -import java.util.List; - -@Slf4j -@Component -@Order(0) -public class SqlEmbeddingListener implements CommandLineRunner { - - @Autowired - private ExemplarManager exemplarManager; - @Autowired - private EmbeddingConfig embeddingConfig; - - @Override - public void run(String... args) { - initSqlExamples(); - } - - public void initSqlExamples() { - try { - if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) { - List exemplars = exemplarManager.getExemplars(); - String collectionName = embeddingConfig.getText2sqlCollectionName(); - exemplarManager.addExemplars(exemplars, collectionName); - } - } catch (Exception e) { - log.error("initSqlExamples error", e); - } - } -} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/TwoPassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/TwoPassSCSqlGenStrategy.java deleted file mode 100644 index c45667792..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/TwoPassSCSqlGenStrategy.java +++ /dev/null @@ -1,119 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.llm; - -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.input.Prompt; -import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.model.output.Response; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.stereotype.Service; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; - -@Service -@Deprecated -public class TwoPassSCSqlGenStrategy extends SqlGenStrategy { - - @Override - public LLMResp generate(LLMReq llmReq) { - //1.recall exemplars - keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq); - - List>> exampleListPool = promptHelper.getFewShotExemplars(llmReq); - - //2.generate schema linking prompt for each self-consistency inference - List linkingPromptPool = new ArrayList<>(); - for (List> exampleList : exampleListPool) { - String prompt = generateLinkingPrompt(llmReq, exampleList); - linkingPromptPool.add(prompt); - } - - List linkingResults = new CopyOnWriteArrayList<>(); - ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); - linkingPromptPool.parallelStream().forEach( - linkingPrompt -> { - Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>()); - keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage()); - Response linkingResult = chatLanguageModel.generate(prompt.toSystemMessage()); - String result = linkingResult.content().text(); - keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result); - linkingResults.add(OutputFormat.getSchemaLink(result)); - } - ); - List sortedList = OutputFormat.formatList(linkingResults); - - //3.generate sql generation prompt for each self-consistency inference - List sqlPromptPool = new ArrayList<>(); - for (int i = 0; i < sortedList.size(); i++) { - String schemaLinkStr = sortedList.get(i); - List> fewshotExampleList = exampleListPool.get(i); - String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList); - sqlPromptPool.add(sqlPrompt); - } - - //4.perform multiple self-consistency inferences parallelly - List sqlTaskPool = new CopyOnWriteArrayList<>(); - sqlPromptPool.parallelStream().forEach(sqlPrompt -> { - Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>()); - keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage()); - Response sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage()); - String result = sqlResult.content().text(); - keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result); - sqlTaskPool.add(result); - }); - - //5.format response. - Pair> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool); - LLMResp llmResp = new LLMResp(); - llmResp.setQuery(llmReq.getQueryText()); - //TODO: should use the same few-shot exemplars as the one chose by self-consistency vote - llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exampleListPool.get(0), sqlMapPair.getRight())); - return llmResp; - } - - private String generateLinkingPrompt(LLMReq llmReq, List> exampleList) { - String instruction = "# Find the schema_links for generating SQL queries for each question " - + "based on the database schema and Foreign keys."; - - List exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT"); - String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT"; - String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList); - - Pair questionPrompt = promptHelper.transformQuestionPrompt(llmReq); - String dbSchema = questionPrompt.getLeft(); - String questionAugmented = questionPrompt.getRight(); - String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:"; - String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented); - - return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt; - } - - private String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, - List> fewshotExampleList) { - String instruction = "# Use the the schema links to generate the SQL queries for each of the questions."; - List exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"); - String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n" - + "SQL: sql"; - - String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList); - Pair questionPrompt = promptHelper.transformQuestionPrompt(llmReq); - String dbSchema = questionPrompt.getLeft(); - String questionAugmented = questionPrompt.getRight(); - String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: "; - String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr); - return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt; - } - - @Override - public void afterPropertiesSet() { - SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this); - } -}