[feature][headless-chat]Introduce LLM-based semantic corrector.#1737

This commit is contained in:
jerryjzhang
2024-10-11 10:32:39 +08:00
parent 0fe7f0e1a7
commit 89b428b39c
3 changed files with 85 additions and 4 deletions

View File

@@ -0,0 +1,79 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.service.AiServices;
import lombok.Data;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
@Slf4j
public class LLMSqlCorrector extends BaseSemanticCorrector {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String INSTRUCTION = ""
+ "\n#Role: You are a senior data engineer experienced in writing SQL."
+ "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer,"
+ "please take a review and help correct it if necessary."
+ "\n#Rules: "
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n4.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n5.ALWAYS enclose alias created by `AS` command in underscores."
+ "\n6.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
@Data
@ToString
static class SemanticSql {
@Description("positive or negative opinion about the sql")
private String opinion;
@Description("corrected sql")
private String sql;
}
interface SemanticSqlExtractor {
SemanticSql generateSemanticSql(String text);
}
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
if (!chatQueryContext.getText2SQLType().enableLLM()) {
return;
}
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatQueryContext.getModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo);
keyPipelineLog.info("LLMSqlCorrector reqPrompt:\n{}", prompt.text());
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMSqlCorrector modelResp:\n{}", s2Sql);
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2Sql.getSql());
}
}
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo) {
Map<String, Object> variable = new HashMap<>();
variable.put("question", queryText);
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
String promptTemplate = INSTRUCTION;
return PromptTemplate.from(promptTemplate).apply(variable);
}
}