(improvement)(chat) Optimize the LLMSqlParser code. (#1195)

This commit is contained in:
lexluo09
2024-06-22 22:51:43 +08:00
committed by GitHub
parent 29694be64e
commit a658b9c45f

View File

@@ -38,17 +38,19 @@ public class LLMSqlParser implements SemanticParser {
log.info("Generate query statement for dataSetId:{}", dataSetId); log.info("Generate query statement for dataSetId:{}", dataSetId);
//3.invoke LLM service to do parsing. //3.invoke LLM service to do parsing.
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId); tryParse(queryCtx, dataSetId);
recall(queryCtx, dataSetId, llmReq);
} catch (Exception e) { } catch (Exception e) {
log.error("Failed to parse query:", e); log.error("Failed to parse query:", e);
} }
} }
private void recall(QueryContext queryCtx, Long dataSetId, LLMReq llmReq) { private void tryParse(QueryContext queryCtx, Long dataSetId) {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries(); int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId);
int currentRetry = 0; int currentRetry = 0;
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>(); Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
ParseResult parseResult = null; ParseResult parseResult = null;
@@ -60,9 +62,8 @@ public class LLMSqlParser implements SemanticParser {
//deduplicate the S2SQL result list and build parserInfo //deduplicate the S2SQL result list and build parserInfo
sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp); sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
if (MapUtils.isNotEmpty(sqlRespMap)) { if (MapUtils.isNotEmpty(sqlRespMap)) {
parseResult = ParseResult.builder() parseResult = ParseResult.builder().dataSetId(dataSetId).llmReq(llmReq)
.dataSetId(dataSetId).llmReq(llmReq).llmResp(llmResp) .llmResp(llmResp).linkingValues(llmReq.getLinking()).build();
.linkingValues(llmReq.getLinking()).build();
break; break;
} }
} }
@@ -71,12 +72,13 @@ public class LLMSqlParser implements SemanticParser {
} }
currentRetry++; currentRetry++;
} }
if (MapUtils.isNotEmpty(sqlRespMap)) { if (MapUtils.isEmpty(sqlRespMap)) {
for (Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) { return;
String sql = entry.getKey(); }
double sqlWeight = entry.getValue().getSqlWeight(); for (Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight); String sql = entry.getKey();
} double sqlWeight = entry.getValue().getSqlWeight();
responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight);
} }
} }