mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
Modify s2ql to s2sql and add LLMSemanticQuery containing S2SQLQuery and MetricInterpretQuery. (#350)
This commit is contained in:
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
public enum AgentToolType {
|
||||
RULE,
|
||||
LLM_S2QL,
|
||||
LLM_S2SQL,
|
||||
PLUGIN,
|
||||
INTERPRET
|
||||
}
|
||||
|
||||
@@ -39,8 +39,8 @@ public class OptimizationConfig {
|
||||
@Value("${candidate.threshold}")
|
||||
private Double candidateThreshold;
|
||||
|
||||
@Value("${user.s2ql.switch:false}")
|
||||
private boolean useS2qlSwitch;
|
||||
@Value("${user.s2SQL.switch:false}")
|
||||
private boolean useS2SqlSwitch;
|
||||
|
||||
@Value("${embedding.mapper.word.min:4}")
|
||||
private int embeddingMapperWordMin;
|
||||
|
||||
@@ -27,7 +27,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getLogicSql())) {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
return;
|
||||
}
|
||||
work(queryReq, semanticParseInfo);
|
||||
@@ -83,12 +83,12 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
needAddFields.removeAll(selectFields);
|
||||
needAddFields.remove(TimeDimensionEnum.DAY.getChName());
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(logicSql, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(replaceFields);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||
}
|
||||
|
||||
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
Long modelId = semanticParseInfo.getModel().getModel();
|
||||
|
||||
List<SchemaElement> metrics = getMetricElements(modelId);
|
||||
@@ -105,7 +105,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(logicSql, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(aggregateSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(Long modelId) {
|
||||
|
||||
@@ -15,14 +15,14 @@ public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
|
||||
return;
|
||||
}
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(replaceSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.corrector;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
@@ -34,15 +34,15 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getLogicSql());
|
||||
sqlInfo.setLogicSql(replaceAlias);
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModelId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getLogicSql(), fieldNameMap);
|
||||
sqlInfo.setLogicSql(sql);
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -57,8 +57,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getLogicSql(), fieldValueToFieldNames);
|
||||
sqlInfo.setLogicSql(sql);
|
||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -91,7 +91,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
)));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getLogicSql(), filedNameToValueMap, false);
|
||||
sqlInfo.setLogicSql(sql);
|
||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String logicSql = sqlInfo.getLogicSql();
|
||||
String logicSql = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
//add alias field name
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
|
||||
@@ -71,14 +71,14 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(SqlParserAddHelper.addGroupBy(logicSql, groupByFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(logicSql, groupByFields));
|
||||
|
||||
addAggregate(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getLogicSql());
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getLogicSql(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(havingSql);
|
||||
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2ql.S2QLDateHelper;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.S2SQLDateHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -46,7 +46,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
|
||||
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to logicSql :{}", queryFilter);
|
||||
@@ -57,27 +57,27 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
logicSql = SqlParserAddHelper.addWhere(logicSql, expression);
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
logicSql = SqlParserReplaceHelper.replaceFunction(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(logicSql);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
String currentDate = S2QLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
logicSql = SqlParserAddHelper.addParenthesisToWhere(logicSql);
|
||||
logicSql = SqlParserAddHelper.addWhere(logicSql, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
}
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
}
|
||||
|
||||
private String getQueryFilter(QueryFilters queryFilters) {
|
||||
@@ -106,9 +106,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getLogicSql(),
|
||||
String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -20,7 +20,7 @@ public class SatisfactionChecker {
|
||||
// check all the parse info in candidate
|
||||
public static boolean check(QueryContext queryContext) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
if (query.getQueryMode().equals(S2QLQuery.QUERY_MODE)) {
|
||||
if (query.getQueryMode().equals(S2SQLQuery.QUERY_MODE)) {
|
||||
continue;
|
||||
}
|
||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
||||
|
||||
@@ -7,33 +7,32 @@ import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.HashMap;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class MetricInterpretParser implements SemanticParser {
|
||||
@@ -71,7 +70,7 @@ public class MetricInterpretParser implements SemanticParser {
|
||||
|
||||
private void buildQuery(Long modelId, QueryContext queryContext,
|
||||
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE);
|
||||
LLMSemanticQuery metricInterpretQuery = QueryManager.createLLMQuery(MetricInterpretQuery.QUERY_MODE);
|
||||
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
|
||||
metrics, schemaElementMatches, toolName);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
@@ -12,9 +12,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
@@ -66,18 +66,18 @@ public class LLMRequestService {
|
||||
public boolean check(QueryContext queryCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig);
|
||||
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2SQLParser.class, llmParserConfig);
|
||||
return true;
|
||||
}
|
||||
if (SatisfactionChecker.check(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText());
|
||||
log.info("skip {}, queryText:{}", LLMS2SQLParser.class, request.getQueryText());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
|
||||
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2QL);
|
||||
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL);
|
||||
if (agentService.containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
@@ -89,7 +89,7 @@ public class LLMRequestService {
|
||||
|
||||
public CommonAgentTool getParserTool(QueryReq request, Long modelId) {
|
||||
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
|
||||
AgentToolType.LLM_S2QL);
|
||||
AgentToolType.LLM_S2SQL);
|
||||
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> modelIds = tool.getModelIds();
|
||||
@@ -131,7 +131,7 @@ public class LLMRequestService {
|
||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2QLDateHelper.getReferenceDate(modelId);
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(modelId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
@@ -20,11 +20,11 @@ import org.springframework.stereotype.Service;
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMResponseService {
|
||||
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2ql, Double weight) {
|
||||
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
||||
if (Objects.isNull(weight)) {
|
||||
weight = 0D;
|
||||
}
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE);
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
Long modelId = parseResult.getModelId();
|
||||
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||
@@ -38,7 +38,7 @@ public class LLMResponseService {
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2QL(s2ql);
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
@@ -1,19 +1,19 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class LLMS2QLParser implements SemanticParser {
|
||||
public class LLMS2SQLParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
@@ -32,7 +32,7 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
//3.get agent tool and determine whether to skip this parser.
|
||||
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelId);
|
||||
if (Objects.isNull(commonAgentTool)) {
|
||||
log.info("no tool in this agent, skip {}", LLMS2QLParser.class);
|
||||
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
|
||||
return;
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
@@ -62,7 +62,7 @@ public class LLMS2QLParser implements SemanticParser {
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("LLMS2QLParser error", e);
|
||||
log.error("parse", e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
@@ -11,7 +11,7 @@ import java.util.List;
|
||||
import java.util.Objects;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
public class S2QLDateHelper {
|
||||
public class S2SQLDateHelper {
|
||||
|
||||
public static String getReferenceDate(Long modelId) {
|
||||
String defaultDate = DateUtils.getBeforeDate(0);
|
||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.net.URI;
|
||||
@@ -101,7 +101,7 @@ public class FunctionBasedParser extends PluginParser {
|
||||
log.info("user decide Model:{}", modelId);
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||
if (S2QLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
if (S2SQLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
return false;
|
||||
}
|
||||
if (plugin.getParseModeConfig() == null) {
|
||||
|
||||
@@ -16,7 +16,7 @@ import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
@@ -40,10 +40,10 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
ExplainSqlReq explainSqlReq = null;
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
try {
|
||||
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId());
|
||||
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(sqlInfo.getCorrectS2SQL(), parseInfo.getModelId());
|
||||
explainSqlReq = ExplainSqlReq.builder()
|
||||
.queryTypeEnum(QueryTypeEnum.SQL)
|
||||
.queryReq(queryS2QLReq)
|
||||
.queryReq(queryS2SQLReq)
|
||||
.build();
|
||||
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
|
||||
if (Objects.nonNull(explain)) {
|
||||
@@ -105,9 +105,9 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
protected void initS2SqlByStruct() {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(queryStructReq);
|
||||
QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq);
|
||||
parseInfo.getSqlInfo().setS2QL(queryS2QLReq.getSql());
|
||||
parseInfo.getSqlInfo().setLogicSql(queryS2QLReq.getSql());
|
||||
QueryS2SQLReq queryS2SQLReq = queryStructReq.convert(queryStructReq);
|
||||
parseInfo.getSqlInfo().setS2SQL(queryS2SQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(queryS2SQLReq.getSql());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package com.tencent.supersonic.chat.query;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -16,25 +16,45 @@ public class QueryManager {
|
||||
|
||||
private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
|
||||
private static Map<String, PluginSemanticQuery> pluginQueryMap = new ConcurrentHashMap<>();
|
||||
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static void register(SemanticQuery query) {
|
||||
if (query instanceof RuleSemanticQuery) {
|
||||
ruleQueryMap.put(query.getQueryMode(), (RuleSemanticQuery) query);
|
||||
} else if (query instanceof PluginSemanticQuery) {
|
||||
pluginQueryMap.put(query.getQueryMode(), (PluginSemanticQuery) query);
|
||||
} else if (query instanceof LLMSemanticQuery) {
|
||||
llmQueryMap.put(query.getQueryMode(), (LLMSemanticQuery) query);
|
||||
}
|
||||
}
|
||||
|
||||
public static SemanticQuery createQuery(String queryMode) {
|
||||
if (containsRuleQuery(queryMode)) {
|
||||
return createRuleQuery(queryMode);
|
||||
} else {
|
||||
}
|
||||
if (containsPluginQuery(queryMode)) {
|
||||
return createPluginQuery(queryMode);
|
||||
}
|
||||
return createLLMQuery(queryMode);
|
||||
|
||||
}
|
||||
|
||||
public static RuleSemanticQuery createRuleQuery(String queryMode) {
|
||||
RuleSemanticQuery semanticQuery = ruleQueryMap.get(queryMode);
|
||||
return (RuleSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
|
||||
}
|
||||
|
||||
public static PluginSemanticQuery createPluginQuery(String queryMode) {
|
||||
PluginSemanticQuery semanticQuery = pluginQueryMap.get(queryMode);
|
||||
return (PluginSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
|
||||
}
|
||||
|
||||
public static LLMSemanticQuery createLLMQuery(String queryMode) {
|
||||
LLMSemanticQuery semanticQuery = llmQueryMap.get(queryMode);
|
||||
return (LLMSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
|
||||
}
|
||||
|
||||
private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
|
||||
if (Objects.isNull(semanticQuery)) {
|
||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
||||
}
|
||||
@@ -45,17 +65,6 @@ public class QueryManager {
|
||||
}
|
||||
}
|
||||
|
||||
public static PluginSemanticQuery createPluginQuery(String queryMode) {
|
||||
PluginSemanticQuery semanticQuery = pluginQueryMap.get(queryMode);
|
||||
if (Objects.isNull(semanticQuery)) {
|
||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
||||
}
|
||||
try {
|
||||
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
||||
}
|
||||
}
|
||||
public static boolean containsRuleQuery(String queryMode) {
|
||||
if (queryMode == null) {
|
||||
return false;
|
||||
@@ -77,7 +86,7 @@ public class QueryManager {
|
||||
return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery;
|
||||
}
|
||||
|
||||
public static boolean isPluginQuery(String queryMode) {
|
||||
public static boolean containsPluginQuery(String queryMode) {
|
||||
return queryMode != null && pluginQueryMap.containsKey(queryMode);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.tencent.supersonic.chat.query.llm;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.query.BaseSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public abstract class LLMSemanticQuery extends BaseSemanticQuery {
|
||||
|
||||
@Override
|
||||
public void initS2Sql(User user) {
|
||||
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.query.llm.interpret;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class LLmAnswerReq {
|
||||
public class LLMAnswerReq {
|
||||
|
||||
private String queryText;
|
||||
|
||||
@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.query.llm.interpret;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class LLmAnswerResp {
|
||||
public class LLMAnswerResp {
|
||||
private String assistantMessage;
|
||||
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
@@ -35,7 +35,7 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricInterpretQuery extends PluginSemanticQuery {
|
||||
public class MetricInterpretQuery extends LLMSemanticQuery {
|
||||
|
||||
|
||||
public static final String QUERY_MODE = "METRIC_INTERPRET";
|
||||
@@ -56,9 +56,9 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
|
||||
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
if (optimizationConfig.isUseS2qlSwitch()) {
|
||||
queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL());
|
||||
queryStructReq.setS2QL(parseInfo.getSqlInfo().getQuerySql());
|
||||
if (optimizationConfig.isUseS2SqlSwitch()) {
|
||||
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getS2SQL());
|
||||
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getQuerySQL());
|
||||
}
|
||||
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user);
|
||||
@@ -151,12 +151,12 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
|
||||
|
||||
public String fetchInterpret(String queryText, String dataText) {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
|
||||
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
|
||||
lLmAnswerReq.setQueryText(queryText);
|
||||
lLmAnswerReq.setPluginOutput(dataText);
|
||||
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
|
||||
JSONObject.toJSONString(lLmAnswerReq));
|
||||
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
|
||||
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
|
||||
if (lLmAnswerResp != null) {
|
||||
return lLmAnswerResp.getAssistantMessage();
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.query.llm.s2sql;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.query.llm.s2sql;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
@@ -6,12 +6,12 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -21,12 +21,12 @@ import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class S2QLQuery extends PluginSemanticQuery {
|
||||
public class S2SQLQuery extends LLMSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "LLM_S2QL";
|
||||
public static final String QUERY_MODE = "LLM_S2SQL";
|
||||
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
|
||||
public S2QLQuery() {
|
||||
public S2SQLQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@@ -39,11 +39,11 @@ public class S2QLQuery extends PluginSemanticQuery {
|
||||
public QueryResult execute(User user) {
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
String querySql = parseInfo.getSqlInfo().getLogicSql();
|
||||
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(querySql, parseInfo.getModelId());
|
||||
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user);
|
||||
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getModelId());
|
||||
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2SQL(queryS2SQLReq, user);
|
||||
|
||||
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
log.info("queryByS2SQL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
|
||||
QueryResult queryResult = new QueryResult();
|
||||
if (Objects.nonNull(queryResp)) {
|
||||
@@ -65,6 +65,6 @@ public class S2QLQuery extends PluginSemanticQuery {
|
||||
@Override
|
||||
public void initS2Sql(User user) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setLogicSql(sqlInfo.getS2QL());
|
||||
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
|
||||
}
|
||||
}
|
||||
@@ -200,9 +200,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
if (optimizationConfig.isUseS2qlSwitch()) {
|
||||
queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL());
|
||||
queryStructReq.setLogicSql(parseInfo.getSqlInfo().getLogicSql());
|
||||
if (optimizationConfig.isUseS2SqlSwitch()) {
|
||||
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getS2SQL());
|
||||
queryStructReq.setCorrectS2SQL(parseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
}
|
||||
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.List;
|
||||
@@ -24,7 +24,7 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
|
||||
return;
|
||||
}
|
||||
String queryMode = semanticParseInfo.getQueryMode();
|
||||
if (QueryManager.isPluginQuery(queryMode) && !S2QLQuery.QUERY_MODE.equals(queryMode)) {
|
||||
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
|
||||
return;
|
||||
}
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.List;
|
||||
@@ -18,22 +18,22 @@ public class EntityInfoParseResponder implements ParseResponder {
|
||||
|
||||
@Override
|
||||
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
|
||||
List<ChatParseDO> chatParseDOS) {
|
||||
List<ChatParseDO> chatParseDOS) {
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||
if (CollectionUtils.isEmpty(selectedParses)) {
|
||||
return;
|
||||
}
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
selectedParses.forEach(parseInfo -> {
|
||||
if (QueryManager.isPluginQuery(parseInfo.getQueryMode())
|
||||
&& !S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
|
||||
return;
|
||||
}
|
||||
//1. set entity info
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
|
||||
if (QueryManager.isEntityQuery(parseInfo.getQueryMode())
|
||||
|| QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
|
||||
if (QueryManager.isEntityQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
}
|
||||
//2. set native value
|
||||
|
||||
@@ -68,7 +68,7 @@ public class SqlInfoParseResponder implements ParseResponder {
|
||||
if (StringUtils.isBlank(explainSql)) {
|
||||
return;
|
||||
}
|
||||
parseInfo.getSqlInfo().setQuerySql(explainSql);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explainSql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
|
||||
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String logicSql = sqlInfo.getLogicSql();
|
||||
String logicSql = sqlInfo.getCorrectS2SQL();
|
||||
if (StringUtils.isBlank(logicSql)) {
|
||||
return;
|
||||
}
|
||||
@@ -103,20 +103,20 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
|
||||
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
|
||||
parseInfo.setNativeQuery(false);
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else {
|
||||
parseInfo.setNativeQuery(true);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(
|
||||
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));
|
||||
|
||||
@@ -27,7 +27,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
@@ -291,11 +291,11 @@ public class QueryServiceImpl implements QueryService {
|
||||
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||
|
||||
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
||||
|
||||
String correctorSql = parseInfo.getSqlInfo().getLogicSql();
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
|
||||
@@ -321,11 +321,11 @@ public class QueryServiceImpl implements QueryService {
|
||||
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||
|
||||
log.info("correctorSql after replacing:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setLogicSql(correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
String explainSql = semanticQuery.explain(user);
|
||||
if (StringUtils.isNotBlank(explainSql)) {
|
||||
parseInfo.getSqlInfo().setQuerySql(explainSql);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explainSql);
|
||||
}
|
||||
}
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
@@ -522,7 +522,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
|
||||
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData, ChatParseDO chatParseDO) {
|
||||
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
||||
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
return parseInfo;
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
|
||||
|
||||
@@ -9,7 +9,7 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.llm.s2ql.ModelResolver;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||
@@ -20,7 +20,7 @@ public class ComponentFactory {
|
||||
|
||||
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
|
||||
private static List<SemanticParser> semanticParsers = new ArrayList<>();
|
||||
private static List<SemanticCorrector> s2QLCorrections = new ArrayList<>();
|
||||
private static List<SemanticCorrector> s2SQLCorrections = new ArrayList<>();
|
||||
private static SemanticInterpreter semanticInterpreter;
|
||||
private static List<ParseResponder> parseResponders = new ArrayList<>();
|
||||
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
|
||||
@@ -35,8 +35,8 @@ public class ComponentFactory {
|
||||
}
|
||||
|
||||
public static List<SemanticCorrector> getSqlCorrections() {
|
||||
return CollectionUtils.isEmpty(s2QLCorrections) ? init(SemanticCorrector.class,
|
||||
s2QLCorrections) : s2QLCorrections;
|
||||
return CollectionUtils.isEmpty(s2SQLCorrections) ? init(SemanticCorrector.class,
|
||||
s2SQLCorrections) : s2SQLCorrections;
|
||||
}
|
||||
|
||||
public static List<ParseResponder> getParseResponders() {
|
||||
|
||||
@@ -13,7 +13,7 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
@@ -124,19 +124,19 @@ public class QueryReqBuilder {
|
||||
}
|
||||
|
||||
/**
|
||||
* convert to QueryS2QLReq
|
||||
* convert to QueryS2SQLReq
|
||||
*
|
||||
* @param querySql
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
public static QueryS2QLReq buildS2QLReq(String querySql, Long modelId) {
|
||||
QueryS2QLReq queryS2QLReq = new QueryS2QLReq();
|
||||
public static QueryS2SQLReq buildS2SQLReq(String querySql, Long modelId) {
|
||||
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
|
||||
if (Objects.nonNull(querySql)) {
|
||||
queryS2QLReq.setSql(querySql);
|
||||
queryS2SQLReq.setSql(querySql);
|
||||
}
|
||||
queryS2QLReq.setModelId(modelId);
|
||||
return queryS2QLReq;
|
||||
queryS2SQLReq.setModelId(modelId);
|
||||
return queryS2SQLReq;
|
||||
}
|
||||
|
||||
private static List<Aggregator> getAggregatorByMetric(AggregateTypeEnum aggregateType, SchemaElement metric) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2ql;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@@ -14,7 +14,7 @@ import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
class LLMS2QLParserTest {
|
||||
class LLMS2SQLParserTest {
|
||||
|
||||
@Test
|
||||
void setFilter() {
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateModeUtils;
|
||||
import com.tencent.supersonic.common.util.SqlFilterUtils;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -25,7 +25,7 @@ import org.mockito.Mockito;
|
||||
class QueryReqBuilderTest {
|
||||
|
||||
@Test
|
||||
void buildS2QLReq() {
|
||||
void buildS2SQLReq() {
|
||||
init();
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
queryStructReq.setModelId(1L);
|
||||
@@ -50,17 +50,17 @@ class QueryReqBuilderTest {
|
||||
orders.add(order);
|
||||
queryStructReq.setOrders(orders);
|
||||
|
||||
QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq);
|
||||
QueryS2SQLReq queryS2SQLReq = queryStructReq.convert(queryStructReq);
|
||||
Assert.assertEquals(
|
||||
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2QLReq.getSql());
|
||||
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2SQLReq.getSql());
|
||||
|
||||
queryStructReq.setNativeQuery(true);
|
||||
queryS2QLReq = queryStructReq.convert(queryStructReq);
|
||||
queryS2SQLReq = queryStructReq.convert(queryStructReq);
|
||||
Assert.assertEquals(
|
||||
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
|
||||
+ "ORDER BY uv LIMIT 2000",
|
||||
queryS2QLReq.getSql());
|
||||
queryS2SQLReq.getSql());
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user