From a658b9c45f547e6d4967dcdc0e4fee559a039445 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sat, 22 Jun 2024 22:51:43 +0800 Subject: [PATCH] (improvement)(chat) Optimize the LLMSqlParser code. (#1195) --- .../chat/parser/llm/LLMSqlParser.java | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java index 3f8d6a7af..5309a7426 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java @@ -38,17 +38,19 @@ public class LLMSqlParser implements SemanticParser { log.info("Generate query statement for dataSetId:{}", dataSetId); //3.invoke LLM service to do parsing. - LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId); - recall(queryCtx, dataSetId, llmReq); + tryParse(queryCtx, dataSetId); } catch (Exception e) { log.error("Failed to parse query:", e); } } - private void recall(QueryContext queryCtx, Long dataSetId, LLMReq llmReq) { + private void tryParse(QueryContext queryCtx, Long dataSetId) { LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries(); + + LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId); + int currentRetry = 0; Map sqlRespMap = new HashMap<>(); ParseResult parseResult = null; @@ -60,9 +62,8 @@ public class LLMSqlParser implements SemanticParser { //deduplicate the S2SQL result list and build parserInfo sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp); if (MapUtils.isNotEmpty(sqlRespMap)) { - parseResult = ParseResult.builder() - .dataSetId(dataSetId).llmReq(llmReq).llmResp(llmResp) - .linkingValues(llmReq.getLinking()).build(); + parseResult = ParseResult.builder().dataSetId(dataSetId).llmReq(llmReq) + .llmResp(llmResp).linkingValues(llmReq.getLinking()).build(); break; } } @@ -71,12 +72,13 @@ public class LLMSqlParser implements SemanticParser { } currentRetry++; } - if (MapUtils.isNotEmpty(sqlRespMap)) { - for (Entry entry : sqlRespMap.entrySet()) { - String sql = entry.getKey(); - double sqlWeight = entry.getValue().getSqlWeight(); - responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight); - } + if (MapUtils.isEmpty(sqlRespMap)) { + return; + } + for (Entry entry : sqlRespMap.entrySet()) { + String sql = entry.getKey(); + double sqlWeight = entry.getValue().getSqlWeight(); + responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight); } }