From f899d23b6320719b29c71cf1a894903ea1168e53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9F=AF=E6=85=95=E7=81=B5?= <1985312383@qq.com> Date: Sat, 21 Jun 2025 04:57:04 +0800 Subject: [PATCH] add new chat corrector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在助理最终执行物理SQL前,加入一步LLM优化性能功能 --- .../chat/server/executor/SqlExecutor.java | 10 +- .../server/service/impl/AgentServiceImpl.java | 2 + .../service/impl/ChatQueryServiceImpl.java | 2 + .../supersonic/headless/api/pojo/SqlInfo.java | 3 + .../api/pojo/enums/ChatWorkflowState.java | 1 + .../corrector/LLMPhysicalSqlCorrector.java | 97 ++++++++++++++++++ .../chat/parser/llm/PromptHelper.java | 3 +- .../server/utils/ChatWorkflowEngine.java | 26 +++++ .../main/resources/META-INF/spring.factories | 3 +- .../main/resources/META-INF/spring.factories | 3 +- pom.xml | 10 +- webapp/package.json | 5 + webapp/packages/chat-sdk/src/common/type.ts | 1 + .../src/components/ChatItem/SqlItem.tsx | 99 ++++++++++--------- 14 files changed, 214 insertions(+), 51 deletions(-) create mode 100644 headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMPhysicalSqlCorrector.java create mode 100644 webapp/package.json diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index 0871dbae8..02a5df72d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -75,8 +75,12 @@ public class SqlExecutor implements ChatQueryExecutor { return null; } - QuerySqlReq sqlReq = - QuerySqlReq.builder().sql(parseInfo.getSqlInfo().getCorrectedS2SQL()).build(); + // 使用querySQL,它已经包含了所有修正(包括物理SQL修正) + String finalSql = StringUtils.isNotBlank(parseInfo.getSqlInfo().getQuerySQL()) + ? parseInfo.getSqlInfo().getQuerySQL() + : parseInfo.getSqlInfo().getCorrectedS2SQL(); + + QuerySqlReq sqlReq = QuerySqlReq.builder().sql(finalSql).build(); sqlReq.setSqlInfo(parseInfo.getSqlInfo()); sqlReq.setDataSetId(parseInfo.getDataSetId()); @@ -90,7 +94,7 @@ public class SqlExecutor implements ChatQueryExecutor { queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime); if (queryResp != null) { queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); - queryResult.setQuerySql(queryResp.getSql()); + queryResult.setQuerySql(finalSql); queryResult.setQueryResults(queryResp.getResultList()); queryResult.setQueryColumns(queryResp.getColumns()); queryResult.setQueryState(QueryState.SUCCESS); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 2d76875bb..f09a4e30d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -39,6 +40,7 @@ public class AgentServiceImpl extends ServiceImpl implem private MemoryService memoryService; @Autowired + @Lazy private ChatQueryService chatQueryService; @Autowired diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 484e22305..1aca41e2d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -49,6 +49,7 @@ import net.sf.jsqlparser.schema.Column; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -66,6 +67,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { @Autowired private SemanticLayerService semanticLayerService; @Autowired + @Lazy private AgentService agentService; private final List chatQueryParsers = ComponentFactory.getChatParsers(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java index 3eab2bccf..ede33c71a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java @@ -16,4 +16,7 @@ public class SqlInfo implements Serializable { // SQL to be executed finally private String querySQL; + + // Physical SQL corrected by LLM for performance optimization + private String correctedQuerySQL; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java index 953f1f020..79ba4d77d 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java @@ -8,5 +8,6 @@ public enum ChatWorkflowState { VALIDATING, SQL_CORRECTING, PROCESSING, + PHYSICAL_SQL_CORRECTING, FINISHED } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMPhysicalSqlCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMPhysicalSqlCorrector.java new file mode 100644 index 000000000..48a61f9ad --- /dev/null +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMPhysicalSqlCorrector.java @@ -0,0 +1,97 @@ +package com.tencent.supersonic.headless.chat.corrector; + +import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; +import com.tencent.supersonic.common.util.ChatAppManager; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.chat.ChatQueryContext; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.structured.Description; +import dev.langchain4j.provider.ModelProvider; +import dev.langchain4j.service.AiServices; +import lombok.Data; +import lombok.ToString; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * 物理SQL修正器 - 使用LLM优化物理SQL性能 + */ +@Slf4j +public class LLMPhysicalSqlCorrector extends BaseSemanticCorrector { + + private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + + public static final String APP_KEY = "PHYSICAL_SQL_CORRECTOR"; + private static final String INSTRUCTION = "" + + "#Role: You are a senior database performance optimization expert experienced in SQL tuning." + + "\n\n#Task: You will be provided with a user question and the corresponding physical SQL query," + + " please analyze and optimize this SQL to improve query performance." + "\n\n#Rules:" + + "\n1. ALWAYS add appropriate index hints if the database supports them." + + "\n2. Optimize JOIN order by placing smaller tables first." + + "\n3. Add reasonable query limits to prevent large result sets if no LIMIT exists." + + "\n4. Optimize WHERE condition order by placing high-selectivity conditions first." + + "\n5. Ensure the optimized SQL is syntactically correct and logically equivalent." + + "\n6. If the SQL is already well-optimized, return 'positive'." + + "\n\n#Question: {{question}}" + "\n\n#OriginalSQL: {{sql}}" + "\n\n#Response:"; + + public LLMPhysicalSqlCorrector() { + ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("物理SQL修正") + .appModule(AppModule.CHAT).description("通过大模型对物理SQL做性能优化").enable(false).build()); + } + + @Data + @ToString + static class PhysicalSql { + @Description("either positive or negative") + private String opinion; + + @Description("optimized sql if negative") + private String sql; + } + + interface PhysicalSqlExtractor { + PhysicalSql generatePhysicalSql(String text); + } + + @Override + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY); + if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp) + || !chatApp.isEnable()) { + return; + } + + ChatLanguageModel chatLanguageModel = + ModelProvider.getChatModel(chatApp.getChatModelConfig()); + PhysicalSqlExtractor extractor = + AiServices.create(PhysicalSqlExtractor.class, chatLanguageModel); + Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(), + semanticParseInfo, chatApp.getPrompt()); + PhysicalSql physicalSql = + extractor.generatePhysicalSql(prompt.toUserMessage().singleText()); + keyPipelineLog.info("LLMPhysicalSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), + physicalSql); + if ("NEGATIVE".equalsIgnoreCase(physicalSql.getOpinion()) + && StringUtils.isNotBlank(physicalSql.getSql())) { + semanticParseInfo.getSqlInfo().setCorrectedQuerySQL(physicalSql.getSql()); + } + } + + private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo, + String promptTemplate) { + Map variable = new HashMap<>(); + variable.put("question", queryText); + variable.put("sql", semanticParseInfo.getSqlInfo().getQuerySQL()); + + return PromptTemplate.from(promptTemplate).apply(variable); + } +} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index c5d935b3b..a319b8491 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -52,7 +52,8 @@ public class PromptHelper { for (int i = 0; i < selfConsistencyNumber; i++) { List shuffledList = new ArrayList<>(exemplars); // only shuffle the exemplars from config - List subList=shuffledList.subList(llmReq.getDynamicExemplars().size(),shuffledList.size()); + List subList = + shuffledList.subList(llmReq.getDynamicExemplars().size(), shuffledList.size()); Collections.shuffle(subList); results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber))); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 6fcbb983a..044fff206 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -8,6 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.corrector.LLMPhysicalSqlCorrector; import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector; import com.tencent.supersonic.headless.chat.mapper.SchemaMapper; import com.tencent.supersonic.headless.chat.parser.SemanticParser; @@ -76,6 +77,10 @@ public class ChatWorkflowEngine { long start = System.currentTimeMillis(); performTranslating(queryCtx, parseResult); parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start); + queryCtx.setChatWorkflowState(ChatWorkflowState.PHYSICAL_SQL_CORRECTING); + break; + case PHYSICAL_SQL_CORRECTING: + performPhysicalSqlCorrecting(queryCtx); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); break; default: @@ -162,4 +167,25 @@ public class ChatWorkflowEngine { parseResult.setErrorMsg(String.join("\n", errorMsg)); } } + + private void performPhysicalSqlCorrecting(ChatQueryContext queryCtx) { + List candidateQueries = queryCtx.getCandidateQueries(); + if (CollectionUtils.isNotEmpty(candidateQueries)) { + for (SemanticQuery semanticQuery : candidateQueries) { + for (SemanticCorrector corrector : semanticCorrectors) { + if (corrector instanceof LLMPhysicalSqlCorrector) { + corrector.correct(queryCtx, semanticQuery.getParseInfo()); + // 如果物理SQL被修正了,更新querySQL为修正后的版本 + SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); + if (StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedQuerySQL())) { + parseInfo.getSqlInfo().setQuerySQL(parseInfo.getSqlInfo().getCorrectedQuerySQL()); + log.info("Physical SQL corrected and updated querySQL: {}", + parseInfo.getSqlInfo().getQuerySQL()); + } + break; + } + } + } + } + } } diff --git a/launchers/headless/src/main/resources/META-INF/spring.factories b/launchers/headless/src/main/resources/META-INF/spring.factories index 0adcac744..cd5643357 100644 --- a/launchers/headless/src/main/resources/META-INF/spring.factories +++ b/launchers/headless/src/main/resources/META-INF/spring.factories @@ -14,7 +14,8 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\ com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\ com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector,\ - com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector + com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector,\ + com.tencent.supersonic.headless.chat.corrector.LLMPhysicalSqlCorrector com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\ com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 8b17394eb..c12dfcea5 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -15,7 +15,8 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\ com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\ com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector,\ - com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector + com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector,\ + com.tencent.supersonic.headless.chat.corrector.LLMPhysicalSqlCorrector com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\ com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl diff --git a/pom.xml b/pom.xml index 8db6a8336..1a81d9238 100644 --- a/pom.xml +++ b/pom.xml @@ -10,9 +10,9 @@ ${revision} + common auth chat - common launchers headless @@ -31,6 +31,7 @@ 21 21 21 + 21 UTF-8 4.9 6.1.0 @@ -254,6 +255,13 @@ ${java.source.version} ${java.target.version} ${file.encoding} + + + org.projectlombok + lombok + ${lombok.version} + + diff --git a/webapp/package.json b/webapp/package.json new file mode 100644 index 000000000..cc1b48756 --- /dev/null +++ b/webapp/package.json @@ -0,0 +1,5 @@ +{ + "dependencies": { + "supersonic-chat-sdk": "link:packages/chat-sdk" + } +} diff --git a/webapp/packages/chat-sdk/src/common/type.ts b/webapp/packages/chat-sdk/src/common/type.ts index da70d4c51..e706311e4 100644 --- a/webapp/packages/chat-sdk/src/common/type.ts +++ b/webapp/packages/chat-sdk/src/common/type.ts @@ -79,6 +79,7 @@ export type SqlInfoType = { parsedS2SQL: string; correctedS2SQL: string; querySQL: string; + correctedQuerySQL?: string; }; export type ChatContextType = { diff --git a/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx b/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx index 96439d770..d33275683 100644 --- a/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx +++ b/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx @@ -58,32 +58,28 @@ const SqlItem: React.FC = ({ const getSchemaMapText = () => { return ` Schema映射 -${schema?.fieldNameList?.length > 0 ? `名称:${schema.fieldNameList.join('、')}` : ''}${ - schema?.values?.length > 0 +${schema?.fieldNameList?.length > 0 ? `名称:${schema.fieldNameList.join('、')}` : ''}${schema?.values?.length > 0 ? ` 取值:${schema.values - .map((item: any) => { - return `${item.fieldName}: ${item.fieldValue}`; - }) - .join('、')}` + .map((item: any) => { + return `${item.fieldName}: ${item.fieldValue}`; + }) + .join('、')}` : '' - }${ - priorExts + }${priorExts ? ` 附加:${priorExts}` : '' - }${ - terms?.length > 0 + }${terms?.length > 0 ? ` 术语:${terms - .map((item: any) => { - return `${item.name}${item.alias?.length > 0 ? `(${item.alias.join(',')})` : ''}: ${ - item.description + .map((item: any) => { + return `${item.name}${item.alias?.length > 0 ? `(${item.alias.join(',')})` : ''}: ${item.description }`; - }) - .join('、')}` + }) + .join('、')}` : '' - } + } `; }; @@ -91,16 +87,16 @@ ${schema?.fieldNameList?.length > 0 ? `名称:${schema.fieldNameList.join('、 const getFewShotText = () => { return ` Few-shot示例${fewShots - .map((item: any, index: number) => { - return ` + .map((item: any, index: number) => { + return ` 示例${index + 1}: 问题:${item.question} SQL: ${format(item.sql)} `; - }) - .join('')} + }) + .join('')} `; }; @@ -120,6 +116,14 @@ ${format(sqlInfo.correctedS2SQL)} `; }; + const getCorrectedQuerySQLText = () => { + return ` +物理SQL修正 + +${format(sqlInfo.correctedQuerySQL || '')} +`; + }; + const getQuerySQLText = () => { return ` 最终执行SQL @@ -155,6 +159,9 @@ ${executeErrorMsg} if (sqlInfo.correctedS2SQL) { text += getCorrectedS2SQLText(); } + if (sqlInfo.correctedQuerySQL) { + text += getCorrectedQuerySQLText(); + } if (sqlInfo.querySQL) { text += getQuerySQLText(); } @@ -183,9 +190,8 @@ ${executeErrorMsg}
{llmReq && (
{ setSqlType(sqlType === 'schemaMap' ? '' : 'schemaMap'); }} @@ -195,9 +201,8 @@ ${executeErrorMsg} )} {fewShots.length > 0 && (
{ setSqlType(sqlType === 'fewShots' ? '' : 'fewShots'); }} @@ -207,9 +212,8 @@ ${executeErrorMsg} )} {sqlInfo.parsedS2SQL && (
{ setSqlType(sqlType === 'parsedS2SQL' ? '' : 'parsedS2SQL'); }} @@ -219,9 +223,8 @@ ${executeErrorMsg} )} {sqlInfo.correctedS2SQL && (
{ setSqlType(sqlType === 'correctedS2SQL' ? '' : 'correctedS2SQL'); }} @@ -229,16 +232,26 @@ ${executeErrorMsg} 修正S2SQL
)} + {sqlInfo.correctedQuerySQL && ( +
{ + setSqlType(sqlType === 'correctedQuerySQL' ? '' : 'correctedQuerySQL'); + }} + > + 物理SQL修正 +
+ )} {sqlInfo.querySQL && (
{ setSqlType(sqlType === 'querySQL' ? '' : 'querySQL'); }} > - 最终执行SQL + {sqlInfo.correctedQuerySQL ? '最终执行SQL' : '最终执行SQL'}
)}
{sqlType === 'schemaMap' && (
@@ -290,9 +302,8 @@ ${executeErrorMsg}
{terms .map((item: any) => { - return `${item.name}${ - item.alias?.length > 0 ? `(${item.alias.join(',')})` : '' - }: ${item.description}`; + return `${item.name}${item.alias?.length > 0 ? `(${item.alias.join(',')})` : '' + }: ${item.description}`; }) .join('、')}