From 89b428b39cb7993697608060a3d89bc57bbf0e50 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Fri, 11 Oct 2024 10:32:39 +0800 Subject: [PATCH] [feature][headless-chat]Introduce LLM-based semantic corrector.#1737 --- .../chat/server/memory/MemoryReviewTask.java | 9 ++- .../server/service/impl/AgentServiceImpl.java | 1 + .../chat/corrector/LLMSqlCorrector.java | 79 +++++++++++++++++++ 3 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index de225e886..8014a9256 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -31,10 +31,11 @@ public class MemoryReviewTask { private static final String INSTRUCTION = "" + "\n#Role: You are a senior data engineer experienced in writing SQL." - + "\n#Task: Your will be provided with a user question and the SQL written by junior engineer," - + "please take a review and give your opinion." + "\n#Rules: " - + "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." - + "2.NO NEED to include date filter in the where clause if not explicitly expressed in the `Question`." + + "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer," + + "please take a review and give your opinion." + + "\n#Rules: " + + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." + + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." + "\n#Question: %s" + "\n#Schema: %s" + "\n#SideInfo: %s" + "\n#SQL: %s" + "\n#Response: "; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index dda61aa1b..a183a94e0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -106,6 +106,7 @@ public class AgentServiceImpl extends ServiceImpl implem private synchronized void doExecuteAgentExamples(Agent agent) { if (!agent.containsDatasetTool() + || !agent.enableMemoryReview() || !ModelConfigHelper.testConnection( ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL)) || CollectionUtils.isEmpty(agent.getExamples())) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java new file mode 100644 index 000000000..31d1467e3 --- /dev/null +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java @@ -0,0 +1,79 @@ +package com.tencent.supersonic.headless.chat.corrector; + +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.chat.ChatQueryContext; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.structured.Description; +import dev.langchain4j.provider.ModelProvider; +import dev.langchain4j.service.AiServices; +import lombok.Data; +import lombok.ToString; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; + +@Slf4j +public class LLMSqlCorrector extends BaseSemanticCorrector { + + private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + + public static final String INSTRUCTION = "" + + "\n#Role: You are a senior data engineer experienced in writing SQL." + + "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer," + + "please take a review and help correct it if necessary." + + "\n#Rules: " + + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`." + + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." + + "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + + "\n4.ALWAYS use `with` statement if nested aggregation is needed." + + "\n5.ALWAYS enclose alias created by `AS` command in underscores." + + "\n6.ALWAYS translate alias created by `AS` command to the same language as the `#Question`." + + "\n#Question:{{question}} #InputSQL:{{sql}} #Response:"; + + @Data + @ToString + static class SemanticSql { + @Description("positive or negative opinion about the sql") + private String opinion; + + @Description("corrected sql") + private String sql; + } + + interface SemanticSqlExtractor { + SemanticSql generateSemanticSql(String text); + } + + @Override + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + if (!chatQueryContext.getText2SQLType().enableLLM()) { + return; + } + + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatQueryContext.getModelConfig()); + SemanticSqlExtractor extractor = + AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); + Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo); + keyPipelineLog.info("LLMSqlCorrector reqPrompt:\n{}", prompt.text()); + SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); + keyPipelineLog.info("LLMSqlCorrector modelResp:\n{}", s2Sql); + if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) { + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2Sql.getSql()); + } + } + + private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo) { + Map variable = new HashMap<>(); + variable.put("question", queryText); + variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); + + String promptTemplate = INSTRUCTION; + return PromptTemplate.from(promptTemplate).apply(variable); + } +}