mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(semantic) support llm multiple parsing sql (#290)
This commit is contained in:
@@ -69,44 +69,61 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig);
|
||||
return;
|
||||
}
|
||||
//1.determine whether to skip this parser.
|
||||
if (SatisfactionChecker.check(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText());
|
||||
return;
|
||||
}
|
||||
try {
|
||||
//2.get modelId from queryCtx and chatCtx.
|
||||
Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
|
||||
if (Objects.isNull(modelId) || modelId <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
//3.get agent tool and determine whether to skip this parser.
|
||||
CommonAgentTool commonAgentTool = getParserTool(request, modelId);
|
||||
if (Objects.isNull(commonAgentTool)) {
|
||||
log.info("no tool in this agent, skip {}", LLMS2QLParser.class);
|
||||
return;
|
||||
}
|
||||
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig);
|
||||
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
//5. get and update parserInfo and corrector sql
|
||||
Map<String, Double> sqlWeight = llmResp.getSqlWeight();
|
||||
|
||||
ParseResult parseResult = ParseResult.builder().request(request)
|
||||
.commonAgentTool(commonAgentTool).llmReq(llmReq).llmResp(llmResp).build();
|
||||
|
||||
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
|
||||
|
||||
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
|
||||
|
||||
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
|
||||
|
||||
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
|
||||
if (Objects.isNull(sqlWeight) || sqlWeight.size() <= 0) {
|
||||
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, llmResp.getSqlOutput());
|
||||
} else {
|
||||
sqlWeight.forEach((sql, weight) -> {
|
||||
addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, sql);
|
||||
});
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("LLMS2QLParser error", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void addParseInfo(QueryContext queryCtx, ParseResult parseResult, Long modelId,
|
||||
CommonAgentTool commonAgentTool, String sql) {
|
||||
|
||||
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
|
||||
|
||||
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql);
|
||||
|
||||
parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql());
|
||||
|
||||
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
|
||||
}
|
||||
|
||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
||||
@@ -172,7 +189,7 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FilterExpression> filterExpressions) {
|
||||
List<FilterExpression> filterExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FilterExpression expression : filterExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
@@ -229,7 +246,7 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||
}
|
||||
|
||||
@@ -257,7 +274,7 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, CommonAgentTool commonAgentTool,
|
||||
ParseResult parseResult) {
|
||||
ParseResult parseResult) {
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
|
||||
@@ -414,7 +431,7 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
|
||||
|
||||
@@ -450,7 +467,7 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
LLMParserConfig llmParserConfig) {
|
||||
Set<String> results = semanticSchema.getDimensions(modelId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.query.llm.s2ql;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@@ -18,5 +19,5 @@ public class LLMResp {
|
||||
|
||||
private String schemaLinkStr;
|
||||
|
||||
private String correctorSql;
|
||||
private Map<String, Double> sqlWeight;
|
||||
}
|
||||
|
||||
@@ -4,14 +4,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
@@ -42,13 +39,13 @@ public class S2QLQuery extends PluginSemanticQuery {
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) {
|
||||
LLMResp llmResp = getLlmResp();
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(llmResp);
|
||||
String querySql = parseInfo.getSqlInfo().getLogicSql();
|
||||
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(querySql);
|
||||
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user);
|
||||
|
||||
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
|
||||
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
|
||||
QueryResult queryResult = new QueryResult();
|
||||
if (Objects.nonNull(queryResp)) {
|
||||
@@ -67,14 +64,8 @@ public class S2QLQuery extends PluginSemanticQuery {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private LLMResp getLlmResp() {
|
||||
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
|
||||
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
|
||||
return parseResult.getLlmResp();
|
||||
}
|
||||
|
||||
private QueryS2QLReq getQueryS2QLReq(LLMResp llmResp) {
|
||||
return QueryReqBuilder.buildS2QLReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
|
||||
private QueryS2QLReq getQueryS2QLReq(String sql) {
|
||||
return QueryReqBuilder.buildS2QLReq(sql, parseInfo.getModelId());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -83,7 +74,7 @@ public class S2QLQuery extends PluginSemanticQuery {
|
||||
try {
|
||||
explainSqlReq = ExplainSqlReq.builder()
|
||||
.queryTypeEnum(QueryTypeEnum.SQL)
|
||||
.queryReq(getQueryS2QLReq(getLlmResp()))
|
||||
.queryReq(getQueryS2QLReq(parseInfo.getSqlInfo().getLogicSql()))
|
||||
.build();
|
||||
return semanticInterpreter.explain(explainSqlReq, user);
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -19,14 +19,12 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||
@@ -36,7 +34,6 @@ import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.chat.service.StatisticsService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -339,10 +336,8 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
||||
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
|
||||
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
|
||||
LLMResp llmResp = parseResult.getLlmResp();
|
||||
String correctorSql = llmResp.getCorrectorSql();
|
||||
|
||||
String correctorSql = parseInfo.getSqlInfo().getLogicSql();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FilterExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
|
||||
@@ -366,11 +361,6 @@ public class QueryServiceImpl implements QueryService {
|
||||
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||
|
||||
log.info("correctorSql after replacing:{}", correctorSql);
|
||||
llmResp.setCorrectorSql(correctorSql);
|
||||
parseResult.setLlmResp(llmResp);
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.getSqlInfo().setLogicSql(correctorSql);
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
ExplainResp explain = semanticQuery.explain(user);
|
||||
|
||||
Reference in New Issue
Block a user