mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
[improvement][chat]Inject schema info into the prompt of LLMSqlCorrector.
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.corrector;
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
@@ -34,11 +35,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
|||||||
+ "please take a review and help correct it if necessary." + "\n#Rules: "
|
+ "please take a review and help correct it if necessary." + "\n#Rules: "
|
||||||
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
|
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
|
||||||
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
|
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
|
||||||
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
+ "\n3.SQL columns and values must be mentioned in the `#Schema`."
|
||||||
+ "\n4.ALWAYS use `with` statement if nested aggregation is needed."
|
+ "\n#Question:{{question}} #Schema:{{schema}} #InputSQL:{{sql}} #Response:";
|
||||||
+ "\n5.ALWAYS enclose alias declared by `AS` command in underscores."
|
|
||||||
+ "\n6.Alias created by `AS` command must be in the same language ast the `Question`."
|
|
||||||
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
|
|
||||||
|
|
||||||
public LLMSqlCorrector() {
|
public LLMSqlCorrector() {
|
||||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
|
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
|
||||||
@@ -67,12 +65,15 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Text2SQLExemplar exemplar = (Text2SQLExemplar)semanticParseInfo.getProperties()
|
||||||
|
.get(Text2SQLExemplar.PROPERTY_KEY);
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel =
|
ChatLanguageModel chatLanguageModel =
|
||||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||||
SemanticSqlExtractor extractor =
|
SemanticSqlExtractor extractor =
|
||||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||||
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
|
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
|
||||||
semanticParseInfo, chatApp.getPrompt());
|
semanticParseInfo, chatApp.getPrompt(), exemplar);
|
||||||
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
|
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
|
||||||
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
|
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
|
||||||
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
|
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
|
||||||
@@ -81,10 +82,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
|
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
|
||||||
String promptTemplate) {
|
String promptTemplate, Text2SQLExemplar exemplar) {
|
||||||
Map<String, Object> variable = new HashMap<>();
|
Map<String, Object> variable = new HashMap<>();
|
||||||
variable.put("question", queryText);
|
variable.put("question", queryText);
|
||||||
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||||
|
variable.put("schema", exemplar.getDbSchema());
|
||||||
|
|
||||||
return PromptTemplate.from(promptTemplate).apply(variable);
|
return PromptTemplate.from(promptTemplate).apply(variable);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
|
|||||||
for (QueryOptimizer queryOptimizer : ComponentFactory.getQueryOptimizers()) {
|
for (QueryOptimizer queryOptimizer : ComponentFactory.getQueryOptimizers()) {
|
||||||
queryOptimizer.rewrite(queryStatement);
|
queryOptimizer.rewrite(queryStatement);
|
||||||
}
|
}
|
||||||
|
log.info("translated query SQL: [{}]",
|
||||||
|
StringUtils.normalizeSpace(queryStatement.getSql()));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
queryStatement.setErrMsg(e.getMessage());
|
queryStatement.setErrMsg(e.getMessage());
|
||||||
log.error("Failed to translate query [{}]", e.getMessage(), e);
|
log.error("Failed to translate query [{}]", e.getMessage(), e);
|
||||||
|
|||||||
Reference in New Issue
Block a user