add new chat corrector

在助理最终执行物理SQL前,加入一步LLM优化性能功能
This commit is contained in:
柯慕灵
2025-06-21 04:57:04 +08:00
parent 87355533b4
commit f899d23b63
14 changed files with 214 additions and 51 deletions

View File

@@ -16,4 +16,7 @@ public class SqlInfo implements Serializable {
// SQL to be executed finally
private String querySQL;
// Physical SQL corrected by LLM for performance optimization
private String correctedQuerySQL;
}

View File

@@ -8,5 +8,6 @@ public enum ChatWorkflowState {
VALIDATING,
SQL_CORRECTING,
PROCESSING,
PHYSICAL_SQL_CORRECTING,
FINISHED
}

View File

@@ -0,0 +1,97 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.service.AiServices;
import lombok.Data;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* 物理SQL修正器 - 使用LLM优化物理SQL性能
*/
@Slf4j
public class LLMPhysicalSqlCorrector extends BaseSemanticCorrector {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String APP_KEY = "PHYSICAL_SQL_CORRECTOR";
private static final String INSTRUCTION = ""
+ "#Role: You are a senior database performance optimization expert experienced in SQL tuning."
+ "\n\n#Task: You will be provided with a user question and the corresponding physical SQL query,"
+ " please analyze and optimize this SQL to improve query performance." + "\n\n#Rules:"
+ "\n1. ALWAYS add appropriate index hints if the database supports them."
+ "\n2. Optimize JOIN order by placing smaller tables first."
+ "\n3. Add reasonable query limits to prevent large result sets if no LIMIT exists."
+ "\n4. Optimize WHERE condition order by placing high-selectivity conditions first."
+ "\n5. Ensure the optimized SQL is syntactically correct and logically equivalent."
+ "\n6. If the SQL is already well-optimized, return 'positive'."
+ "\n\n#Question: {{question}}" + "\n\n#OriginalSQL: {{sql}}" + "\n\n#Response:";
public LLMPhysicalSqlCorrector() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("物理SQL修正")
.appModule(AppModule.CHAT).description("通过大模型对物理SQL做性能优化").enable(false).build());
}
@Data
@ToString
static class PhysicalSql {
@Description("either positive or negative")
private String opinion;
@Description("optimized sql if negative")
private String sql;
}
interface PhysicalSqlExtractor {
PhysicalSql generatePhysicalSql(String text);
}
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY);
if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp)
|| !chatApp.isEnable()) {
return;
}
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
PhysicalSqlExtractor extractor =
AiServices.create(PhysicalSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
semanticParseInfo, chatApp.getPrompt());
PhysicalSql physicalSql =
extractor.generatePhysicalSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMPhysicalSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
physicalSql);
if ("NEGATIVE".equalsIgnoreCase(physicalSql.getOpinion())
&& StringUtils.isNotBlank(physicalSql.getSql())) {
semanticParseInfo.getSqlInfo().setCorrectedQuerySQL(physicalSql.getSql());
}
}
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
String promptTemplate) {
Map<String, Object> variable = new HashMap<>();
variable.put("question", queryText);
variable.put("sql", semanticParseInfo.getSqlInfo().getQuerySQL());
return PromptTemplate.from(promptTemplate).apply(variable);
}
}

View File

@@ -52,7 +52,8 @@ public class PromptHelper {
for (int i = 0; i < selfConsistencyNumber; i++) {
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
// only shuffle the exemplars from config
List<Text2SQLExemplar> subList=shuffledList.subList(llmReq.getDynamicExemplars().size(),shuffledList.size());
List<Text2SQLExemplar> subList =
shuffledList.subList(llmReq.getDynamicExemplars().size(), shuffledList.size());
Collections.shuffle(subList);
results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber)));
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.LLMPhysicalSqlCorrector;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
@@ -76,6 +77,10 @@ public class ChatWorkflowEngine {
long start = System.currentTimeMillis();
performTranslating(queryCtx, parseResult);
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
queryCtx.setChatWorkflowState(ChatWorkflowState.PHYSICAL_SQL_CORRECTING);
break;
case PHYSICAL_SQL_CORRECTING:
performPhysicalSqlCorrecting(queryCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
break;
default:
@@ -162,4 +167,25 @@ public class ChatWorkflowEngine {
parseResult.setErrorMsg(String.join("\n", errorMsg));
}
}
private void performPhysicalSqlCorrecting(ChatQueryContext queryCtx) {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
for (SemanticCorrector corrector : semanticCorrectors) {
if (corrector instanceof LLMPhysicalSqlCorrector) {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
// 如果物理SQL被修正了更新querySQL为修正后的版本
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedQuerySQL())) {
parseInfo.getSqlInfo().setQuerySQL(parseInfo.getSqlInfo().getCorrectedQuerySQL());
log.info("Physical SQL corrected and updated querySQL: {}",
parseInfo.getSqlInfo().getQuerySQL());
}
break;
}
}
}
}
}
}