(improvement)(semantic) support llm multiple parsing sql (#290)

This commit is contained in:
lexluo09
2023-10-25 22:23:15 +08:00
committed by GitHub
parent e44e7ca8d5
commit 32e51257f6
4 changed files with 41 additions and 42 deletions

View File

@@ -69,44 +69,61 @@ public class LLMS2QLParser implements SemanticParser {
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig);
return;
}
//1.determine whether to skip this parser.
if (SatisfactionChecker.check(queryCtx)) {
log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText());
return;
}
try {
//2.get modelId from queryCtx and chatCtx.
Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
if (Objects.isNull(modelId) || modelId <= 0) {
return;
}
//3.get agent tool and determine whether to skip this parser.
CommonAgentTool commonAgentTool = getParserTool(request, modelId);
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMS2QLParser.class);
return;
}
//4.construct a request, call the API for the large model, and retrieve the results.
LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig);
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
if (Objects.isNull(llmResp)) {
return;
}
//5. get and update parserInfo and corrector sql
Map<String, Double> sqlWeight = llmResp.getSqlWeight();
ParseResult parseResult = ParseResult.builder().request(request)
.commonAgentTool(commonAgentTool).llmReq(llmReq).llmResp(llmResp).build();
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
if (Objects.isNull(sqlWeight) || sqlWeight.size() <= 0) {
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, llmResp.getSqlOutput());
} else {
sqlWeight.forEach((sql, weight) -> {
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, sql);
});
}
} catch (Exception e) {
log.error("LLMS2QLParser error", e);
}
}
private void addParseInfo(QueryContext queryCtx, ParseResult parseResult, Long modelId,
CommonAgentTool commonAgentTool, String sql) {
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql);
parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql());
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
}
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
@@ -172,7 +189,7 @@ public class LLMS2QLParser implements SemanticParser {
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FilterExpression> filterExpressions) {
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
@@ -229,7 +246,7 @@ public class LLMS2QLParser implements SemanticParser {
}
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
}
@@ -257,7 +274,7 @@ public class LLMS2QLParser implements SemanticParser {
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, CommonAgentTool commonAgentTool,
ParseResult parseResult) {
ParseResult parseResult) {
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
@@ -414,7 +431,7 @@ public class LLMS2QLParser implements SemanticParser {
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
@@ -450,7 +467,7 @@ public class LLMS2QLParser implements SemanticParser {
}
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
LLMParserConfig llmParserConfig) {
Set<String> results = semanticSchema.getDimensions(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.query.llm.s2ql;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
@@ -18,5 +19,5 @@ public class LLMResp {
private String schemaLinkStr;
private String correctorSql;
private Map<String, Double> sqlWeight;
}

View File

@@ -4,14 +4,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
@@ -42,13 +39,13 @@ public class S2QLQuery extends PluginSemanticQuery {
@Override
public QueryResult execute(User user) {
LLMResp llmResp = getLlmResp();
long startTime = System.currentTimeMillis();
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(llmResp);
String querySql = parseInfo.getSqlInfo().getLogicSql();
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(querySql);
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user);
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) {
@@ -67,14 +64,8 @@ public class S2QLQuery extends PluginSemanticQuery {
return queryResult;
}
private LLMResp getLlmResp() {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
return parseResult.getLlmResp();
}
private QueryS2QLReq getQueryS2QLReq(LLMResp llmResp) {
return QueryReqBuilder.buildS2QLReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
private QueryS2QLReq getQueryS2QLReq(String sql) {
return QueryReqBuilder.buildS2QLReq(sql, parseInfo.getModelId());
}
@Override
@@ -83,7 +74,7 @@ public class S2QLQuery extends PluginSemanticQuery {
try {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(getQueryS2QLReq(getLlmResp()))
.queryReq(getQueryS2QLReq(parseInfo.getSqlInfo().getLogicSql()))
.build();
return semanticInterpreter.explain(explainSqlReq, user);
} catch (Exception e) {

View File

@@ -19,14 +19,12 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
@@ -36,7 +34,6 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -339,10 +336,8 @@ public class QueryServiceImpl implements QueryService {
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
LLMResp llmResp = parseResult.getLlmResp();
String correctorSql = llmResp.getCorrectorSql();
String correctorSql = parseInfo.getSqlInfo().getLogicSql();
log.info("correctorSql before replacing:{}", correctorSql);
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
List<FilterExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
@@ -366,11 +361,6 @@ public class QueryServiceImpl implements QueryService {
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
llmResp.setCorrectorSql(correctorSql);
parseResult.setLlmResp(llmResp);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
parseInfo.setProperties(properties);
parseInfo.getSqlInfo().setLogicSql(correctorSql);
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(user);