From 6fcd105249df0cd4f4fa9fce3981cdbdf8c3eb92 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Wed, 18 Dec 2024 12:04:06 +0800 Subject: [PATCH] [improvement][chat]Inject schema info into the prompt of LLMSqlCorrector. --- .../headless/chat/corrector/LLMSqlCorrector.java | 16 +++++++++------- .../translator/DefaultSemanticTranslator.java | 3 ++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java index a8ebb910a..368e2c05a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; @@ -34,11 +35,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { + "please take a review and help correct it if necessary." + "\n#Rules: " + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`." + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." - + "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." - + "\n4.ALWAYS use `with` statement if nested aggregation is needed." - + "\n5.ALWAYS enclose alias declared by `AS` command in underscores." - + "\n6.Alias created by `AS` command must be in the same language ast the `Question`." - + "\n#Question:{{question}} #InputSQL:{{sql}} #Response:"; + + "\n3.SQL columns and values must be mentioned in the `#Schema`." + + "\n#Question:{{question}} #Schema:{{schema}} #InputSQL:{{sql}} #Response:"; public LLMSqlCorrector() { ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正") @@ -67,12 +65,15 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { return; } + Text2SQLExemplar exemplar = (Text2SQLExemplar)semanticParseInfo.getProperties() + .get(Text2SQLExemplar.PROPERTY_KEY); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig()); SemanticSqlExtractor extractor = AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(), - semanticParseInfo, chatApp.getPrompt()); + semanticParseInfo, chatApp.getPrompt(), exemplar); SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql); if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) { @@ -81,10 +82,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { } private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo, - String promptTemplate) { + String promptTemplate, Text2SQLExemplar exemplar) { Map variable = new HashMap<>(); variable.put("question", queryText); variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); + variable.put("schema", exemplar.getDbSchema()); return PromptTemplate.from(promptTemplate).apply(variable); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java index afa4001fe..a67e51f81 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java @@ -46,7 +46,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator { optimizer.rewrite(queryStatement); } } - log.info("translated query SQL: [{}]", queryStatement.getSql()); + log.info("translated query SQL: [{}]", + StringUtils.normalizeSpace(queryStatement.getSql())); } catch (Exception e) { queryStatement.setErrMsg(e.getMessage()); log.error("Failed to translate query [{}]", e.getMessage(), e);