From 32e51257f6d721b974bcbe81708263ed327c12ed Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Wed, 25 Oct 2023 22:23:15 +0800 Subject: [PATCH] (improvement)(semantic) support llm multiple parsing sql (#290) --- .../chat/parser/llm/s2ql/LLMS2QLParser.java | 45 +++++++++++++------ .../chat/query/llm/s2ql/LLMResp.java | 3 +- .../chat/query/llm/s2ql/S2QLQuery.java | 21 +++------ .../chat/service/impl/QueryServiceImpl.java | 14 +----- 4 files changed, 41 insertions(+), 42 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java index 69fa1e77d..e089ae2c7 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java @@ -69,44 +69,61 @@ public class LLMS2QLParser implements SemanticParser { log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig); return; } + //1.determine whether to skip this parser. if (SatisfactionChecker.check(queryCtx)) { log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText()); return; } try { + //2.get modelId from queryCtx and chatCtx. Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId()); if (Objects.isNull(modelId) || modelId <= 0) { return; } - + //3.get agent tool and determine whether to skip this parser. CommonAgentTool commonAgentTool = getParserTool(request, modelId); if (Objects.isNull(commonAgentTool)) { log.info("no tool in this agent, skip {}", LLMS2QLParser.class); return; } - + //4.construct a request, call the API for the large model, and retrieve the results. LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig); LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig); if (Objects.isNull(llmResp)) { return; } + //5. get and update parserInfo and corrector sql + Map sqlWeight = llmResp.getSqlWeight(); + ParseResult parseResult = ParseResult.builder().request(request) .commonAgentTool(commonAgentTool).llmReq(llmReq).llmResp(llmResp).build(); - SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult); - - SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput()); - - llmResp.setCorrectorSql(semanticCorrectInfo.getSql()); - - updateParseInfo(semanticCorrectInfo, modelId, parseInfo); + if (Objects.isNull(sqlWeight) || sqlWeight.size() <= 0) { + addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, llmResp.getSqlOutput()); + } else { + sqlWeight.forEach((sql, weight) -> { + addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, sql); + }); + } } catch (Exception e) { log.error("LLMS2QLParser error", e); } } + private void addParseInfo(QueryContext queryCtx, ParseResult parseResult, Long modelId, + CommonAgentTool commonAgentTool, String sql) { + + SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult); + + SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql); + + parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql()); + + updateParseInfo(semanticCorrectInfo, modelId, parseInfo); + } + private Set getElements(Long modelId, List allFields, List elements) { return elements.stream() .filter(schemaElement -> modelId.equals(schemaElement.getModel()) @@ -172,7 +189,7 @@ public class LLMS2QLParser implements SemanticParser { } private List getDimensionFilter(Map fieldNameToElement, - List filterExpressions) { + List filterExpressions) { List result = Lists.newArrayList(); for (FilterExpression expression : filterExpressions) { QueryFilter dimensionFilter = new QueryFilter(); @@ -229,7 +246,7 @@ public class LLMS2QLParser implements SemanticParser { } private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, - FilterOperatorEnum... operatorEnums) { + FilterOperatorEnum... operatorEnums) { return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); } @@ -257,7 +274,7 @@ public class LLMS2QLParser implements SemanticParser { } private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, CommonAgentTool commonAgentTool, - ParseResult parseResult) { + ParseResult parseResult) { PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE); SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId)); @@ -414,7 +431,7 @@ public class LLMS2QLParser implements SemanticParser { protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, - LLMParserConfig llmParserConfig) { + LLMParserConfig llmParserConfig) { Set results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig); @@ -450,7 +467,7 @@ public class LLMS2QLParser implements SemanticParser { } private Set getTopNFieldNames(Long modelId, SemanticSchema semanticSchema, - LLMParserConfig llmParserConfig) { + LLMParserConfig llmParserConfig) { Set results = semanticSchema.getDimensions(modelId).stream() .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) .limit(llmParserConfig.getDimensionTopN()) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMResp.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMResp.java index 9563a0ee4..e68f669a5 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMResp.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMResp.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.query.llm.s2ql; import java.util.List; +import java.util.Map; import lombok.Data; @Data @@ -18,5 +19,5 @@ public class LLMResp { private String schemaLinkStr; - private String correctorSql; + private Map sqlWeight; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java index 687a99159..04d690526 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/S2QLQuery.java @@ -4,14 +4,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; -import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; -import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.QueryColumn; -import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; @@ -42,13 +39,13 @@ public class S2QLQuery extends PluginSemanticQuery { @Override public QueryResult execute(User user) { - LLMResp llmResp = getLlmResp(); long startTime = System.currentTimeMillis(); - QueryS2QLReq queryS2QLReq = getQueryS2QLReq(llmResp); + String querySql = parseInfo.getSqlInfo().getLogicSql(); + QueryS2QLReq queryS2QLReq = getQueryS2QLReq(querySql); QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user); - log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput()); + log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql); QueryResult queryResult = new QueryResult(); if (Objects.nonNull(queryResp)) { @@ -67,14 +64,8 @@ public class S2QLQuery extends PluginSemanticQuery { return queryResult; } - private LLMResp getLlmResp() { - String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); - ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class); - return parseResult.getLlmResp(); - } - - private QueryS2QLReq getQueryS2QLReq(LLMResp llmResp) { - return QueryReqBuilder.buildS2QLReq(llmResp.getCorrectorSql(), parseInfo.getModelId()); + private QueryS2QLReq getQueryS2QLReq(String sql) { + return QueryReqBuilder.buildS2QLReq(sql, parseInfo.getModelId()); } @Override @@ -83,7 +74,7 @@ public class S2QLQuery extends PluginSemanticQuery { try { explainSqlReq = ExplainSqlReq.builder() .queryTypeEnum(QueryTypeEnum.SQL) - .queryReq(getQueryS2QLReq(getLlmResp())) + .queryReq(getQueryS2QLReq(parseInfo.getSqlInfo().getLogicSql())) .build(); return semanticInterpreter.explain(explainSqlReq, user); } catch (Exception e) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index f0c2cace7..e2914f706 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -19,14 +19,12 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; -import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QuerySelector; -import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery; import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; @@ -36,7 +34,6 @@ import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.StatisticsService; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.SolvedQueryManager; -import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.util.ContextUtils; @@ -339,10 +336,8 @@ public class QueryServiceImpl implements QueryService { if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { Map> filedNameToValueMap = new HashMap<>(); Map> havingFiledNameToValueMap = new HashMap<>(); - String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); - ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class); - LLMResp llmResp = parseResult.getLlmResp(); - String correctorSql = llmResp.getCorrectorSql(); + + String correctorSql = parseInfo.getSqlInfo().getLogicSql(); log.info("correctorSql before replacing:{}", correctorSql); List whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql); List havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql); @@ -366,11 +361,6 @@ public class QueryServiceImpl implements QueryService { correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions); log.info("correctorSql after replacing:{}", correctorSql); - llmResp.setCorrectorSql(correctorSql); - parseResult.setLlmResp(llmResp); - Map properties = new HashMap<>(); - properties.put(Constants.CONTEXT, parseResult); - parseInfo.setProperties(properties); parseInfo.getSqlInfo().setLogicSql(correctorSql); semanticQuery.setParseInfo(parseInfo); ExplainResp explain = semanticQuery.explain(user);