(improvement)(semantic) add weight to multi sql in llmParser (#293)

This commit is contained in:
lexluo09
2023-10-26 16:42:01 +08:00
committed by GitHub
parent 32e51257f6
commit 38099c8cc7

View File

@@ -100,10 +100,10 @@ public class LLMS2QLParser implements SemanticParser {
.commonAgentTool(commonAgentTool).llmReq(llmReq).llmResp(llmResp).build();
if (Objects.isNull(sqlWeight) || sqlWeight.size() <= 0) {
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, llmResp.getSqlOutput());
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, llmResp.getSqlOutput(), 1D);
} else {
sqlWeight.forEach((sql, weight) -> {
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, sql);
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, sql, weight);
});
}
@@ -113,9 +113,9 @@ public class LLMS2QLParser implements SemanticParser {
}
private void addParseInfo(QueryContext queryCtx, ParseResult parseResult, Long modelId,
CommonAgentTool commonAgentTool, String sql) {
CommonAgentTool commonAgentTool, String sql, Double weight) {
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult, weight);
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql);
@@ -274,7 +274,10 @@ public class LLMS2QLParser implements SemanticParser {
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, CommonAgentTool commonAgentTool,
ParseResult parseResult) {
ParseResult parseResult, Double weight) {
if (Objects.isNull(weight)) {
weight = 0D;
}
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
@@ -285,7 +288,7 @@ public class LLMS2QLParser implements SemanticParser {
properties.put("name", commonAgentTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2QL(parseResult.getLlmResp().getSqlOutput());