mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) rename logicSql to correctS2SQL and the rule does not perform corrector operation (#373)
This commit is contained in:
@@ -71,10 +71,10 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return result;
|
||||
}
|
||||
|
||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String logicSql) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(logicSql));
|
||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(logicSql));
|
||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(logicSql));
|
||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||
return;
|
||||
@@ -82,13 +82,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
needAddFields.removeAll(selectFields);
|
||||
needAddFields.remove(TimeDimensionEnum.DAY.getChName());
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(logicSql, new ArrayList<>(needAddFields));
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||
}
|
||||
|
||||
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
Long modelId = semanticParseInfo.getModel().getModel();
|
||||
|
||||
List<SchemaElement> metrics = getMetricElements(modelId);
|
||||
@@ -104,7 +104,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(logicSql, metricToAggregate);
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String logicSql = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
//add alias field name
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
|
||||
@@ -47,7 +47,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
@@ -56,12 +56,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
return;
|
||||
}
|
||||
if (SqlParserSelectHelper.hasGroupBy(logicSql)) {
|
||||
log.info("not add group by ,exist group by in logicSql:{}", logicSql);
|
||||
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
.filter(field -> {
|
||||
@@ -71,7 +71,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(logicSql, groupByFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
|
||||
addAggregate(semanticParseInfo);
|
||||
}
|
||||
|
||||
@@ -44,13 +44,13 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
|
||||
@@ -12,15 +12,15 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||
&& !CollectionUtils.isEmpty(selectFields)
|
||||
&& aggregateFields.size() == selectFields.size()) {
|
||||
return;
|
||||
}
|
||||
addFieldsToSelect(semanticParseInfo, logicSql);
|
||||
addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,38 +46,38 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
|
||||
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to logicSql :{}", queryFilter);
|
||||
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
|
||||
Expression expression = null;
|
||||
try {
|
||||
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
logicSql = SqlParserAddHelper.addWhere(logicSql, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
logicSql = SqlParserReplaceHelper.replaceFunction(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(logicSql);
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
logicSql = SqlParserAddHelper.addParenthesisToWhere(logicSql);
|
||||
logicSql = SqlParserAddHelper.addWhere(logicSql, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
}
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private String getQueryFilter(QueryFilters queryFilters) {
|
||||
@@ -106,9 +106,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String logicSql = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(logicSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -41,26 +41,24 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
public String explain(User user) {
|
||||
ExplainSqlReq explainSqlReq = null;
|
||||
try {
|
||||
ExplainResp explain = null;
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
|
||||
if (StringUtils.isNotBlank(sqlInfo.getCorrectS2SQL())) {
|
||||
//sql
|
||||
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(sqlInfo.getCorrectS2SQL(),
|
||||
parseInfo.getModelId());
|
||||
explainSqlReq = ExplainSqlReq.builder()
|
||||
.queryTypeEnum(QueryTypeEnum.SQL)
|
||||
.queryReq(queryS2SQLReq)
|
||||
.queryReq(QueryReqBuilder.buildS2SQLReq(
|
||||
sqlInfo.getCorrectS2SQL(), parseInfo.getModelId()
|
||||
))
|
||||
.build();
|
||||
explain = semanticInterpreter.explain(explainSqlReq, user);
|
||||
} else {
|
||||
//struct
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
|
||||
explainSqlReq = ExplainSqlReq.builder()
|
||||
.queryTypeEnum(QueryTypeEnum.STRUCT)
|
||||
.queryReq(queryStructReq)
|
||||
.queryReq(QueryReqBuilder.buildStructReq(parseInfo))
|
||||
.build();
|
||||
explain = semanticInterpreter.explain(explainSqlReq, user);
|
||||
}
|
||||
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
|
||||
if (Objects.nonNull(explain)) {
|
||||
return explain.getSql();
|
||||
}
|
||||
|
||||
@@ -38,12 +38,16 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
|
||||
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String logicSql = sqlInfo.getCorrectS2SQL();
|
||||
if (StringUtils.isBlank(logicSql)) {
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
if (StringUtils.isBlank(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
// if S2SQL equals correctS2SQL, than not update the parseInfo.
|
||||
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
|
||||
//set dataInfo
|
||||
try {
|
||||
if (!CollectionUtils.isEmpty(expressions)) {
|
||||
@@ -70,10 +74,9 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
//cannot use metrics in sql to override parse info
|
||||
//List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
//Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||
//parseInfo.setMetrics(metrics);
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
|
||||
parseInfo.setNativeQuery(false);
|
||||
|
||||
@@ -29,6 +29,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.QueryRanker;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
@@ -148,6 +149,10 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
semanticQuery.initS2Sql(queryReq.getUser());
|
||||
// rule
|
||||
if (semanticQuery instanceof RuleSemanticQuery) {
|
||||
continue;
|
||||
}
|
||||
semanticCorrectors.stream().forEach(correction -> {
|
||||
correction.correct(queryReq, semanticQuery.getParseInfo());
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user