mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
add new chat corrector
在助理最终执行物理SQL前,加入一步LLM优化性能功能
This commit is contained in:
@@ -16,4 +16,7 @@ public class SqlInfo implements Serializable {
|
||||
|
||||
// SQL to be executed finally
|
||||
private String querySQL;
|
||||
|
||||
// Physical SQL corrected by LLM for performance optimization
|
||||
private String correctedQuerySQL;
|
||||
}
|
||||
|
||||
@@ -8,5 +8,6 @@ public enum ChatWorkflowState {
|
||||
VALIDATING,
|
||||
SQL_CORRECTING,
|
||||
PROCESSING,
|
||||
PHYSICAL_SQL_CORRECTING,
|
||||
FINISHED
|
||||
}
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 物理SQL修正器 - 使用LLM优化物理SQL性能
|
||||
*/
|
||||
@Slf4j
|
||||
public class LLMPhysicalSqlCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY = "PHYSICAL_SQL_CORRECTOR";
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a senior database performance optimization expert experienced in SQL tuning."
|
||||
+ "\n\n#Task: You will be provided with a user question and the corresponding physical SQL query,"
|
||||
+ " please analyze and optimize this SQL to improve query performance." + "\n\n#Rules:"
|
||||
+ "\n1. ALWAYS add appropriate index hints if the database supports them."
|
||||
+ "\n2. Optimize JOIN order by placing smaller tables first."
|
||||
+ "\n3. Add reasonable query limits to prevent large result sets if no LIMIT exists."
|
||||
+ "\n4. Optimize WHERE condition order by placing high-selectivity conditions first."
|
||||
+ "\n5. Ensure the optimized SQL is syntactically correct and logically equivalent."
|
||||
+ "\n6. If the SQL is already well-optimized, return 'positive'."
|
||||
+ "\n\n#Question: {{question}}" + "\n\n#OriginalSQL: {{sql}}" + "\n\n#Response:";
|
||||
|
||||
public LLMPhysicalSqlCorrector() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("物理SQL修正")
|
||||
.appModule(AppModule.CHAT).description("通过大模型对物理SQL做性能优化").enable(false).build());
|
||||
}
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
static class PhysicalSql {
|
||||
@Description("either positive or negative")
|
||||
private String opinion;
|
||||
|
||||
@Description("optimized sql if negative")
|
||||
private String sql;
|
||||
}
|
||||
|
||||
interface PhysicalSqlExtractor {
|
||||
PhysicalSql generatePhysicalSql(String text);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY);
|
||||
if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp)
|
||||
|| !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
PhysicalSqlExtractor extractor =
|
||||
AiServices.create(PhysicalSqlExtractor.class, chatLanguageModel);
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
|
||||
semanticParseInfo, chatApp.getPrompt());
|
||||
PhysicalSql physicalSql =
|
||||
extractor.generatePhysicalSql(prompt.toUserMessage().singleText());
|
||||
keyPipelineLog.info("LLMPhysicalSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
|
||||
physicalSql);
|
||||
if ("NEGATIVE".equalsIgnoreCase(physicalSql.getOpinion())
|
||||
&& StringUtils.isNotBlank(physicalSql.getSql())) {
|
||||
semanticParseInfo.getSqlInfo().setCorrectedQuerySQL(physicalSql.getSql());
|
||||
}
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
|
||||
String promptTemplate) {
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("question", queryText);
|
||||
variable.put("sql", semanticParseInfo.getSqlInfo().getQuerySQL());
|
||||
|
||||
return PromptTemplate.from(promptTemplate).apply(variable);
|
||||
}
|
||||
}
|
||||
@@ -52,7 +52,8 @@ public class PromptHelper {
|
||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||
// only shuffle the exemplars from config
|
||||
List<Text2SQLExemplar> subList=shuffledList.subList(llmReq.getDynamicExemplars().size(),shuffledList.size());
|
||||
List<Text2SQLExemplar> subList =
|
||||
shuffledList.subList(llmReq.getDynamicExemplars().size(), shuffledList.size());
|
||||
Collections.shuffle(subList);
|
||||
results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber)));
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.corrector.LLMPhysicalSqlCorrector;
|
||||
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
||||
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
@@ -76,6 +77,10 @@ public class ChatWorkflowEngine {
|
||||
long start = System.currentTimeMillis();
|
||||
performTranslating(queryCtx, parseResult);
|
||||
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.PHYSICAL_SQL_CORRECTING);
|
||||
break;
|
||||
case PHYSICAL_SQL_CORRECTING:
|
||||
performPhysicalSqlCorrecting(queryCtx);
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||
break;
|
||||
default:
|
||||
@@ -162,4 +167,25 @@ public class ChatWorkflowEngine {
|
||||
parseResult.setErrorMsg(String.join("\n", errorMsg));
|
||||
}
|
||||
}
|
||||
|
||||
private void performPhysicalSqlCorrecting(ChatQueryContext queryCtx) {
|
||||
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
for (SemanticCorrector corrector : semanticCorrectors) {
|
||||
if (corrector instanceof LLMPhysicalSqlCorrector) {
|
||||
corrector.correct(queryCtx, semanticQuery.getParseInfo());
|
||||
// 如果物理SQL被修正了,更新querySQL为修正后的版本
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
if (StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedQuerySQL())) {
|
||||
parseInfo.getSqlInfo().setQuerySQL(parseInfo.getSqlInfo().getCorrectedQuerySQL());
|
||||
log.info("Physical SQL corrected and updated querySQL: {}",
|
||||
parseInfo.getSqlInfo().getQuerySQL());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user