mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][chat]Inject schema info into the prompt of LLMSqlCorrector.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
@@ -34,11 +35,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
+ "please take a review and help correct it if necessary." + "\n#Rules: "
|
||||
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
|
||||
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
|
||||
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
+ "\n4.ALWAYS use `with` statement if nested aggregation is needed."
|
||||
+ "\n5.ALWAYS enclose alias declared by `AS` command in underscores."
|
||||
+ "\n6.Alias created by `AS` command must be in the same language ast the `Question`."
|
||||
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
|
||||
+ "\n3.SQL columns and values must be mentioned in the `#Schema`."
|
||||
+ "\n#Question:{{question}} #Schema:{{schema}} #InputSQL:{{sql}} #Response:";
|
||||
|
||||
public LLMSqlCorrector() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
|
||||
@@ -67,12 +65,15 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
return;
|
||||
}
|
||||
|
||||
Text2SQLExemplar exemplar = (Text2SQLExemplar)semanticParseInfo.getProperties()
|
||||
.get(Text2SQLExemplar.PROPERTY_KEY);
|
||||
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
|
||||
semanticParseInfo, chatApp.getPrompt());
|
||||
semanticParseInfo, chatApp.getPrompt(), exemplar);
|
||||
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
|
||||
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
|
||||
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
|
||||
@@ -81,10 +82,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
|
||||
String promptTemplate) {
|
||||
String promptTemplate, Text2SQLExemplar exemplar) {
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("question", queryText);
|
||||
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
variable.put("schema", exemplar.getDbSchema());
|
||||
|
||||
return PromptTemplate.from(promptTemplate).apply(variable);
|
||||
}
|
||||
|
||||
@@ -46,7 +46,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
|
||||
optimizer.rewrite(queryStatement);
|
||||
}
|
||||
}
|
||||
log.info("translated query SQL: [{}]", queryStatement.getSql());
|
||||
log.info("translated query SQL: [{}]",
|
||||
StringUtils.normalizeSpace(queryStatement.getSql()));
|
||||
} catch (Exception e) {
|
||||
queryStatement.setErrMsg(e.getMessage());
|
||||
log.error("Failed to translate query [{}]", e.getMessage(), e);
|
||||
|
||||
Reference in New Issue
Block a user