[improvement][chat]Inject schema info into the prompt of LLMSqlCorrector.

This commit is contained in:
jerryjzhang
2024-12-18 12:04:06 +08:00
parent 1d0f5612b7
commit 6fcd105249
2 changed files with 11 additions and 8 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -34,11 +35,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
+ "please take a review and help correct it if necessary." + "\n#Rules: "
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n4.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n5.ALWAYS enclose alias declared by `AS` command in underscores."
+ "\n6.Alias created by `AS` command must be in the same language ast the `Question`."
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
+ "\n3.SQL columns and values must be mentioned in the `#Schema`."
+ "\n#Question:{{question}} #Schema:{{schema}} #InputSQL:{{sql}} #Response:";
public LLMSqlCorrector() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
@@ -67,12 +65,15 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
return;
}
Text2SQLExemplar exemplar = (Text2SQLExemplar)semanticParseInfo.getProperties()
.get(Text2SQLExemplar.PROPERTY_KEY);
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
semanticParseInfo, chatApp.getPrompt());
semanticParseInfo, chatApp.getPrompt(), exemplar);
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
@@ -81,10 +82,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
}
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
String promptTemplate) {
String promptTemplate, Text2SQLExemplar exemplar) {
Map<String, Object> variable = new HashMap<>();
variable.put("question", queryText);
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
variable.put("schema", exemplar.getDbSchema());
return PromptTemplate.from(promptTemplate).apply(variable);
}

View File

@@ -46,7 +46,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
optimizer.rewrite(queryStatement);
}
}
log.info("translated query SQL: [{}]", queryStatement.getSql());
log.info("translated query SQL: [{}]",
StringUtils.normalizeSpace(queryStatement.getSql()));
} catch (Exception e) {
queryStatement.setErrMsg(e.getMessage());
log.error("Failed to translate query [{}]", e.getMessage(), e);