Modify s2ql to s2sql and add LLMSemanticQuery containing S2SQLQuery and MetricInterpretQuery. (#350)

This commit is contained in:
lexluo09
2023-11-09 21:59:27 +08:00
committed by GitHub
parent acee0a36da
commit 7d33c49db8
90 changed files with 379 additions and 357 deletions

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.agent.tool;
public enum AgentToolType {
RULE,
LLM_S2QL,
LLM_S2SQL,
PLUGIN,
INTERPRET
}

View File

@@ -39,8 +39,8 @@ public class OptimizationConfig {
@Value("${candidate.threshold}")
private Double candidateThreshold;
@Value("${user.s2ql.switch:false}")
private boolean useS2qlSwitch;
@Value("${user.s2SQL.switch:false}")
private boolean useS2SqlSwitch;
@Value("${embedding.mapper.word.min:4}")
private int embeddingMapperWordMin;

View File

@@ -27,7 +27,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getLogicSql())) {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
return;
}
work(queryReq, semanticParseInfo);
@@ -83,12 +83,12 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
needAddFields.removeAll(selectFields);
needAddFields.remove(TimeDimensionEnum.DAY.getChName());
String replaceFields = SqlParserAddHelper.addFieldsToSelect(logicSql, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setLogicSql(replaceFields);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
}
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Long modelId = semanticParseInfo.getModel().getModel();
List<SchemaElement> metrics = getMetricElements(modelId);
@@ -105,7 +105,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return;
}
String aggregateSql = SqlParserAddHelper.addAggregateToField(logicSql, metricToAggregate);
semanticParseInfo.getSqlInfo().setLogicSql(aggregateSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(Long modelId) {

View File

@@ -15,14 +15,14 @@ public class GlobalAfterCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
return;
}
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
semanticParseInfo.getSqlInfo().setLogicSql(replaceSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;
}

View File

@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
@@ -34,15 +34,15 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getLogicSql());
sqlInfo.setLogicSql(replaceAlias);
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
sqlInfo.setCorrectS2SQL(replaceAlias);
}
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModelId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getLogicSql(), fieldNameMap);
sqlInfo.setLogicSql(sql);
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);
}
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
@@ -57,8 +57,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getLogicSql(), fieldValueToFieldNames);
sqlInfo.setLogicSql(sql);
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectS2SQL(sql);
}
private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
@@ -91,7 +91,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
)));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getLogicSql(), filedNameToValueMap, false);
sqlInfo.setLogicSql(sql);
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql);
}
}

View File

@@ -31,7 +31,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String logicSql = sqlInfo.getLogicSql();
String logicSql = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
//add alias field name
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
@@ -71,14 +71,14 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true;
})
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setLogicSql(SqlParserAddHelper.addGroupBy(logicSql, groupByFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(logicSql, groupByFields));
addAggregate(semanticParseInfo);
}
private void addAggregate(SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getLogicSql());
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}

View File

@@ -28,8 +28,8 @@ public class HavingCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(metrics)) {
return;
}
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getLogicSql(), metrics);
semanticParseInfo.getSqlInfo().setLogicSql(havingSql);
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
}
}

View File

@@ -12,7 +12,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.llm.s2ql.S2QLDateHelper;
import com.tencent.supersonic.chat.parser.llm.s2sql.S2SQLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -46,7 +46,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to logicSql :{}", queryFilter);
@@ -57,27 +57,27 @@ public class WhereCorrector extends BaseSemanticCorrector {
log.error("parseCondExpression", e);
}
logicSql = SqlParserAddHelper.addWhere(logicSql, expression);
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
}
}
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
logicSql = SqlParserReplaceHelper.replaceFunction(logicSql);
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
}
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getLogicSql();
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(logicSql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) {
String currentDate = S2QLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
if (StringUtils.isNotBlank(currentDate)) {
logicSql = SqlParserAddHelper.addParenthesisToWhere(logicSql);
logicSql = SqlParserAddHelper.addWhere(logicSql, TimeDimensionEnum.DAY.getChName(), currentDate);
}
}
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
}
private String getQueryFilter(QueryFilters queryFilters) {
@@ -106,9 +106,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getLogicSql(),
String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setLogicSql(logicSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
return;
}

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
@@ -20,7 +20,7 @@ public class SatisfactionChecker {
// check all the parse info in candidate
public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (query.getQueryMode().equals(S2QLQuery.QUERY_MODE)) {
if (query.getQueryMode().equals(S2SQLQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {

View File

@@ -7,33 +7,32 @@ import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashMap;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class MetricInterpretParser implements SemanticParser {
@@ -71,7 +70,7 @@ public class MetricInterpretParser implements SemanticParser {
private void buildQuery(Long modelId, QueryContext queryContext,
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE);
LLMSemanticQuery metricInterpretQuery = QueryManager.createLLMQuery(MetricInterpretQuery.QUERY_MODE);
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
metrics, schemaElementMatches, toolName);

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
@@ -12,9 +12,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
@@ -66,18 +66,18 @@ public class LLMRequestService {
public boolean check(QueryContext queryCtx) {
QueryReq request = queryCtx.getRequest();
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig);
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2SQLParser.class, llmParserConfig);
return true;
}
if (SatisfactionChecker.check(queryCtx)) {
log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText());
log.info("skip {}, queryText:{}", LLMS2SQLParser.class, request.getQueryText());
return true;
}
return false;
}
public Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2QL);
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL);
if (agentService.containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>();
}
@@ -89,7 +89,7 @@ public class LLMRequestService {
public CommonAgentTool getParserTool(QueryReq request, Long modelId) {
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
AgentToolType.LLM_S2QL);
AgentToolType.LLM_S2SQL);
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> {
List<Long> modelIds = tool.getModelIds();
@@ -131,7 +131,7 @@ public class LLMRequestService {
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = S2QLDateHelper.getReferenceDate(modelId);
String currentDate = S2SQLDateHelper.getReferenceDate(modelId);
if (StringUtils.isEmpty(currentDate)) {
currentDate = DateUtils.getBeforeDate(0);
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -20,11 +20,11 @@ import org.springframework.stereotype.Service;
@Slf4j
@Service
public class LLMResponseService {
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2ql, Double weight) {
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
if (Objects.isNull(weight)) {
weight = 0D;
}
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE);
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
Long modelId = parseResult.getModelId();
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool();
@@ -38,7 +38,7 @@ public class LLMResponseService {
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2QL(s2ql);
parseInfo.getSqlInfo().setS2SQL(s2SQL);
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();

View File

@@ -1,19 +1,19 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class LLMS2QLParser implements SemanticParser {
public class LLMS2SQLParser implements SemanticParser {
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
@@ -32,7 +32,7 @@ public class LLMS2QLParser implements SemanticParser {
//3.get agent tool and determine whether to skip this parser.
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelId);
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMS2QLParser.class);
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
return;
}
//4.construct a request, call the API for the large model, and retrieve the results.
@@ -62,7 +62,7 @@ public class LLMS2QLParser implements SemanticParser {
}
} catch (Exception e) {
log.error("LLMS2QLParser error", e);
log.error("parse", e);
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.api.pojo.ChatContext;

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -11,7 +11,7 @@ import java.util.List;
import java.util.Objects;
import org.apache.commons.collections.CollectionUtils;
public class S2QLDateHelper {
public class S2SQLDateHelper {
public static String getReferenceDate(Long modelId) {
String defaultDate = DateUtils.getBeforeDate(0);

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
@@ -101,7 +101,7 @@ public class FunctionBasedParser extends PluginParser {
log.info("user decide Model:{}", modelId);
List<Plugin> plugins = getPluginList(queryContext);
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (S2QLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
if (S2SQLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false;
}
if (plugin.getParseModeConfig() == null) {

View File

@@ -16,7 +16,7 @@ import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
import java.util.List;
@@ -40,10 +40,10 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
ExplainSqlReq explainSqlReq = null;
SqlInfo sqlInfo = parseInfo.getSqlInfo();
try {
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(sqlInfo.getLogicSql(), parseInfo.getModelId());
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(sqlInfo.getCorrectS2SQL(), parseInfo.getModelId());
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(queryS2QLReq)
.queryReq(queryS2SQLReq)
.build();
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
if (Objects.nonNull(explain)) {
@@ -105,9 +105,9 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
protected void initS2SqlByStruct() {
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(queryStructReq);
QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq);
parseInfo.getSqlInfo().setS2QL(queryS2QLReq.getSql());
parseInfo.getSqlInfo().setLogicSql(queryS2QLReq.getSql());
QueryS2SQLReq queryS2SQLReq = queryStructReq.convert(queryStructReq);
parseInfo.getSqlInfo().setS2SQL(queryS2SQLReq.getSql());
parseInfo.getSqlInfo().setCorrectS2SQL(queryS2SQLReq.getSql());
}
}

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -16,25 +16,45 @@ public class QueryManager {
private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
private static Map<String, PluginSemanticQuery> pluginQueryMap = new ConcurrentHashMap<>();
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
public static void register(SemanticQuery query) {
if (query instanceof RuleSemanticQuery) {
ruleQueryMap.put(query.getQueryMode(), (RuleSemanticQuery) query);
} else if (query instanceof PluginSemanticQuery) {
pluginQueryMap.put(query.getQueryMode(), (PluginSemanticQuery) query);
} else if (query instanceof LLMSemanticQuery) {
llmQueryMap.put(query.getQueryMode(), (LLMSemanticQuery) query);
}
}
public static SemanticQuery createQuery(String queryMode) {
if (containsRuleQuery(queryMode)) {
return createRuleQuery(queryMode);
} else {
}
if (containsPluginQuery(queryMode)) {
return createPluginQuery(queryMode);
}
return createLLMQuery(queryMode);
}
public static RuleSemanticQuery createRuleQuery(String queryMode) {
RuleSemanticQuery semanticQuery = ruleQueryMap.get(queryMode);
return (RuleSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}
public static PluginSemanticQuery createPluginQuery(String queryMode) {
PluginSemanticQuery semanticQuery = pluginQueryMap.get(queryMode);
return (PluginSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}
public static LLMSemanticQuery createLLMQuery(String queryMode) {
LLMSemanticQuery semanticQuery = llmQueryMap.get(queryMode);
return (LLMSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}
private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
if (Objects.isNull(semanticQuery)) {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
@@ -45,17 +65,6 @@ public class QueryManager {
}
}
public static PluginSemanticQuery createPluginQuery(String queryMode) {
PluginSemanticQuery semanticQuery = pluginQueryMap.get(queryMode);
if (Objects.isNull(semanticQuery)) {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
try {
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
}
public static boolean containsRuleQuery(String queryMode) {
if (queryMode == null) {
return false;
@@ -77,7 +86,7 @@ public class QueryManager {
return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery;
}
public static boolean isPluginQuery(String queryMode) {
public static boolean containsPluginQuery(String queryMode) {
return queryMode != null && pluginQueryMap.containsKey(queryMode);
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.query.llm;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class LLMSemanticQuery extends BaseSemanticQuery {
@Override
public void initS2Sql(User user) {
}
}

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.query.llm.interpret;
import lombok.Data;
@Data
public class LLmAnswerReq {
public class LLMAnswerReq {
private String queryText;

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.query.llm.interpret;
import lombok.Data;
@Data
public class LLmAnswerResp {
public class LLMAnswerResp {
private String assistantMessage;
}

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Aggregator;
@@ -35,7 +35,7 @@ import org.springframework.util.CollectionUtils;
@Slf4j
@Component
public class MetricInterpretQuery extends PluginSemanticQuery {
public class MetricInterpretQuery extends LLMSemanticQuery {
public static final String QUERY_MODE = "METRIC_INTERPRET";
@@ -56,9 +56,9 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (optimizationConfig.isUseS2qlSwitch()) {
queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL());
queryStructReq.setS2QL(parseInfo.getSqlInfo().getQuerySql());
if (optimizationConfig.isUseS2SqlSwitch()) {
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getS2SQL());
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getQuerySQL());
}
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user);
@@ -151,12 +151,12 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
public String fetchInterpret(String queryText, String dataText) {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText);
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
JSONObject.toJSONString(lLmAnswerReq));
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistantMessage();
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.llm.s2ql;
package com.tencent.supersonic.chat.query.llm.s2sql;
import java.util.List;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.llm.s2ql;
package com.tencent.supersonic.chat.query.llm.s2sql;
import java.util.List;
import java.util.Map;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.llm.s2ql;
package com.tencent.supersonic.chat.query.llm.s2sql;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
@@ -6,12 +6,12 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -21,12 +21,12 @@ import org.springframework.stereotype.Component;
@Slf4j
@Component
public class S2QLQuery extends PluginSemanticQuery {
public class S2SQLQuery extends LLMSemanticQuery {
public static final String QUERY_MODE = "LLM_S2QL";
public static final String QUERY_MODE = "LLM_S2SQL";
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
public S2QLQuery() {
public S2SQLQuery() {
QueryManager.register(this);
}
@@ -39,11 +39,11 @@ public class S2QLQuery extends PluginSemanticQuery {
public QueryResult execute(User user) {
long startTime = System.currentTimeMillis();
String querySql = parseInfo.getSqlInfo().getLogicSql();
QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(querySql, parseInfo.getModelId());
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user);
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getModelId());
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2SQL(queryS2SQLReq, user);
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
log.info("queryByS2SQL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) {
@@ -65,6 +65,6 @@ public class S2QLQuery extends PluginSemanticQuery {
@Override
public void initS2Sql(User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setLogicSql(sqlInfo.getS2QL());
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
}
}

View File

@@ -200,9 +200,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
QueryStructReq queryStructReq = convertQueryStruct();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (optimizationConfig.isUseS2qlSwitch()) {
queryStructReq.setS2QL(parseInfo.getSqlInfo().getS2QL());
queryStructReq.setLogicSql(parseInfo.getSqlInfo().getLogicSql());
if (optimizationConfig.isUseS2SqlSwitch()) {
queryStructReq.setS2SQL(parseInfo.getSqlInfo().getS2SQL());
queryStructReq.setCorrectS2SQL(parseInfo.getSqlInfo().getCorrectS2SQL());
}
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
@@ -24,7 +24,7 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
return;
}
String queryMode = semanticParseInfo.getQueryMode();
if (QueryManager.isPluginQuery(queryMode) && !S2QLQuery.QUERY_MODE.equals(queryMode)) {
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
return;
}
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
@@ -18,22 +18,22 @@ public class EntityInfoParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
List<ChatParseDO> chatParseDOS) {
List<ChatParseDO> chatParseDOS) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
if (CollectionUtils.isEmpty(selectedParses)) {
return;
}
QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> {
if (QueryManager.isPluginQuery(parseInfo.getQueryMode())
&& !S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
return;
}
//1. set entity info
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
if (QueryManager.isEntityQuery(parseInfo.getQueryMode())
|| QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
if (QueryManager.isEntityQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo);
}
//2. set native value

View File

@@ -68,7 +68,7 @@ public class SqlInfoParseResponder implements ParseResponder {
if (StringUtils.isBlank(explainSql)) {
return;
}
parseInfo.getSqlInfo().setQuerySql(explainSql);
parseInfo.getSqlInfo().setQuerySQL(explainSql);
}
}

View File

@@ -73,7 +73,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String logicSql = sqlInfo.getLogicSql();
String logicSql = sqlInfo.getCorrectS2SQL();
if (StringUtils.isBlank(logicSql)) {
return;
}
@@ -103,20 +103,20 @@ public class ParserInfoServiceImpl implements ParseInfoService {
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
parseInfo.setNativeQuery(false);
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
} else {
parseInfo.setNativeQuery(true);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));

View File

@@ -27,7 +27,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import com.tencent.supersonic.chat.service.ChatService;
@@ -291,11 +291,11 @@ public class QueryServiceImpl implements QueryService {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getLogicSql();
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
@@ -321,11 +321,11 @@ public class QueryServiceImpl implements QueryService {
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
parseInfo.getSqlInfo().setLogicSql(correctorSql);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
String explainSql = semanticQuery.explain(user);
if (StringUtils.isNotBlank(explainSql)) {
parseInfo.getSqlInfo().setQuerySql(explainSql);
parseInfo.getSqlInfo().setQuerySQL(explainSql);
}
}
semanticQuery.setParseInfo(parseInfo);
@@ -522,7 +522,7 @@ public class QueryServiceImpl implements QueryService {
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData, ChatParseDO chatParseDO) {
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo;
}
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {

View File

@@ -9,7 +9,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.chat.parser.llm.s2ql.ModelResolver;
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
@@ -20,7 +20,7 @@ public class ComponentFactory {
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
private static List<SemanticParser> semanticParsers = new ArrayList<>();
private static List<SemanticCorrector> s2QLCorrections = new ArrayList<>();
private static List<SemanticCorrector> s2SQLCorrections = new ArrayList<>();
private static SemanticInterpreter semanticInterpreter;
private static List<ParseResponder> parseResponders = new ArrayList<>();
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
@@ -35,8 +35,8 @@ public class ComponentFactory {
}
public static List<SemanticCorrector> getSqlCorrections() {
return CollectionUtils.isEmpty(s2QLCorrections) ? init(SemanticCorrector.class,
s2QLCorrections) : s2QLCorrections;
return CollectionUtils.isEmpty(s2SQLCorrections) ? init(SemanticCorrector.class,
s2SQLCorrections) : s2SQLCorrections;
}
public static List<ParseResponder> getParseResponders() {

View File

@@ -13,7 +13,7 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.time.LocalDate;
import java.util.ArrayList;
@@ -124,19 +124,19 @@ public class QueryReqBuilder {
}
/**
* convert to QueryS2QLReq
* convert to QueryS2SQLReq
*
* @param querySql
* @param modelId
* @return
*/
public static QueryS2QLReq buildS2QLReq(String querySql, Long modelId) {
QueryS2QLReq queryS2QLReq = new QueryS2QLReq();
public static QueryS2SQLReq buildS2SQLReq(String querySql, Long modelId) {
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
if (Objects.nonNull(querySql)) {
queryS2QLReq.setSql(querySql);
queryS2SQLReq.setSql(querySql);
}
queryS2QLReq.setModelId(modelId);
return queryS2QLReq;
queryS2SQLReq.setModelId(modelId);
return queryS2SQLReq;
}
private static List<Aggregator> getAggregatorByMetric(AggregateTypeEnum aggregateType, SchemaElement metric) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.s2ql;
package com.tencent.supersonic.chat.parser.llm.s2sql;
import static org.mockito.Mockito.when;
@@ -14,7 +14,7 @@ import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class LLMS2QLParserTest {
class LLMS2SQLParserTest {
@Test
void setFilter() {

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateModeUtils;
import com.tencent.supersonic.common.util.SqlFilterUtils;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
import java.util.Arrays;
@@ -25,7 +25,7 @@ import org.mockito.Mockito;
class QueryReqBuilderTest {
@Test
void buildS2QLReq() {
void buildS2SQLReq() {
init();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setModelId(1L);
@@ -50,17 +50,17 @@ class QueryReqBuilderTest {
orders.add(order);
queryStructReq.setOrders(orders);
QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq);
QueryS2SQLReq queryS2SQLReq = queryStructReq.convert(queryStructReq);
Assert.assertEquals(
"SELECT department, SUM(pv) FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2QLReq.getSql());
+ "GROUP BY department ORDER BY uv LIMIT 2000", queryS2SQLReq.getSql());
queryStructReq.setNativeQuery(true);
queryS2QLReq = queryStructReq.convert(queryStructReq);
queryS2SQLReq = queryStructReq.convert(queryStructReq);
Assert.assertEquals(
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "ORDER BY uv LIMIT 2000",
queryS2QLReq.getSql());
queryS2SQLReq.getSql());
}