(improvement)(semantic) support llm multiple parsing sql (#290)

This commit is contained in:
lexluo09
2023-10-25 22:23:15 +08:00
committed by GitHub
parent e44e7ca8d5
commit 32e51257f6
4 changed files with 41 additions and 42 deletions

View File

@@ -69,44 +69,61 @@ public class LLMS2QLParser implements SemanticParser {
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig);
return;
}
//1.determine whether to skip this parser.
if (SatisfactionChecker.check(queryCtx)) {
log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText());
return;
}
try {
//2.get modelId from queryCtx and chatCtx.
Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
if (Objects.isNull(modelId) || modelId <= 0) {
return;
}
//3.get agent tool and determine whether to skip this parser.
CommonAgentTool commonAgentTool = getParserTool(request, modelId);
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMS2QLParser.class);
return;
}
//4.construct a request, call the API for the large model, and retrieve the results.
LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig);
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
if (Objects.isNull(llmResp)) {
return;
}
//5. get and update parserInfo and corrector sql
Map<String, Double> sqlWeight = llmResp.getSqlWeight();
ParseResult parseResult = ParseResult.builder().request(request)
.commonAgentTool(commonAgentTool).llmReq(llmReq).llmResp(llmResp).build();
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
if (Objects.isNull(sqlWeight) || sqlWeight.size() <= 0) {
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, llmResp.getSqlOutput());
} else {
sqlWeight.forEach((sql, weight) -> {
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, sql);
});
}
} catch (Exception e) {
log.error("LLMS2QLParser error", e);
}
}
private void addParseInfo(QueryContext queryCtx, ParseResult parseResult, Long modelId,
CommonAgentTool commonAgentTool, String sql) {
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql);
parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql());
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
}
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.query.llm.s2ql;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
@@ -18,5 +19,5 @@ public class LLMResp {
private String schemaLinkStr;
private String correctorSql;
private Map<String, Double> sqlWeight;
}

View File

@@ -4,14 +4,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
@@ -42,13 +39,13 @@ public class S2QLQuery extends PluginSemanticQuery {
@Override
public QueryResult execute(User user) {
LLMResp llmResp = getLlmResp();
long startTime = System.currentTimeMillis();
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(llmResp);
String querySql = parseInfo.getSqlInfo().getLogicSql();
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(querySql);
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user);
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) {
@@ -67,14 +64,8 @@ public class S2QLQuery extends PluginSemanticQuery {
return queryResult;
}
private LLMResp getLlmResp() {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
return parseResult.getLlmResp();
}
private QueryS2QLReq getQueryS2QLReq(LLMResp llmResp) {
return QueryReqBuilder.buildS2QLReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
private QueryS2QLReq getQueryS2QLReq(String sql) {
return QueryReqBuilder.buildS2QLReq(sql, parseInfo.getModelId());
}
@Override
@@ -83,7 +74,7 @@ public class S2QLQuery extends PluginSemanticQuery {
try {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(getQueryS2QLReq(getLlmResp()))
.queryReq(getQueryS2QLReq(parseInfo.getSqlInfo().getLogicSql()))
.build();
return semanticInterpreter.explain(explainSqlReq, user);
} catch (Exception e) {

View File

@@ -19,14 +19,12 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
@@ -36,7 +34,6 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -339,10 +336,8 @@ public class QueryServiceImpl implements QueryService {
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
LLMResp llmResp = parseResult.getLlmResp();
String correctorSql = llmResp.getCorrectorSql();
String correctorSql = parseInfo.getSqlInfo().getLogicSql();
log.info("correctorSql before replacing:{}", correctorSql);
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
List<FilterExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
@@ -366,11 +361,6 @@ public class QueryServiceImpl implements QueryService {
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
llmResp.setCorrectorSql(correctorSql);
parseResult.setLlmResp(llmResp);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
parseInfo.setProperties(properties);
parseInfo.getSqlInfo().setLogicSql(correctorSql);
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(user);