diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index 8c78b3e81..de50ec147 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -12,6 +12,7 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; @@ -67,6 +68,7 @@ public class WhereCorrector extends BaseSemanticCorrector { private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL); + correctS2SQL = SqlParserRemoveHelper.removeNumberCondition(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); } @@ -77,8 +79,8 @@ public class WhereCorrector extends BaseSemanticCorrector { String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId()); if (StringUtils.isNotBlank(currentDate)) { correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL); - correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), - currentDate); + correctS2SQL = SqlParserAddHelper.addWhere( + correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate); } } semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index ed8c58d95..c2e47eb53 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -273,36 +273,21 @@ public class QueryServiceImpl implements QueryService { SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); semanticQuery.setParseInfo(parseInfo); - if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { - Map> filedNameToValueMap = new HashMap<>(); - Map> havingFiledNameToValueMap = new HashMap<>(); - + List metrics = queryData.getMetrics().stream().map(o -> o.getName()).collect(Collectors.toList()); + List fields = new ArrayList<>(); + if (Objects.nonNull(parseInfo.getSqlInfo()) + && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) { String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); - log.info("correctorSql before replacing:{}", correctorSql); - // get where filter and having filter - List whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql); - List havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql); - List addWhereConditions = new ArrayList<>(); - List addHavingConditions = new ArrayList<>(); - Set removeWhereFieldNames = new HashSet<>(); - Set removeHavingFieldNames = new HashSet<>(); - // replace where filter - updateFilters(whereExpressionList, queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); - updateDateInfo(queryData, parseInfo, filedNameToValueMap, - whereExpressionList, addWhereConditions, removeWhereFieldNames); - correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); - correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); - // replace having filter - updateFilters(havingExpressionList, queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); - correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); - correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); - - correctorSql = SqlParserAddHelper.addWhere(correctorSql, addWhereConditions); - correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions); - log.info("correctorSql after replacing:{}", correctorSql); - correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql); + fields = SqlParserSelectHelper.getAllFields(correctorSql); + } + if (CollectionUtils.isNotEmpty(fields) && !fields.containsAll(metrics) + && CollectionUtils.isNotEmpty(queryData.getMetrics())) { + //replace metrics + log.info("llm begin replace metrics!"); + replaceMetrics(parseInfo, metrics); + } else if (S2SQLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { + log.info("llm begin revise filters!"); + String correctorSql = reviseCorrectS2SQL(queryData, parseInfo); parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); semanticQuery.setParseInfo(parseInfo); String explainSql = semanticQuery.explain(user); @@ -310,6 +295,10 @@ public class QueryServiceImpl implements QueryService { parseInfo.getSqlInfo().setQuerySQL(explainSql); } } else { + log.info("rule begin replace metrics and revise filters!"); + //remove unvalid filters + validFilter(semanticQuery.getParseInfo().getDimensionFilters()); + validFilter(semanticQuery.getParseInfo().getMetricFilters()); //init s2sql semanticQuery.initS2Sql(user); QueryReq queryReq = new QueryReq(); @@ -324,6 +313,50 @@ public class QueryServiceImpl implements QueryService { return queryResult; } + public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) { + Map> filedNameToValueMap = new HashMap<>(); + Map> havingFiledNameToValueMap = new HashMap<>(); + + String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); + log.info("correctorSql before replacing:{}", correctorSql); + // get where filter and having filter + List whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql); + List havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql); + List addWhereConditions = new ArrayList<>(); + List addHavingConditions = new ArrayList<>(); + Set removeWhereFieldNames = new HashSet<>(); + Set removeHavingFieldNames = new HashSet<>(); + // replace where filter + updateFilters(whereExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); + updateDateInfo(queryData, parseInfo, filedNameToValueMap, + whereExpressionList, addWhereConditions, removeWhereFieldNames); + correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); + correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); + // replace having filter + updateFilters(havingExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); + correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); + correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); + + correctorSql = SqlParserAddHelper.addWhere(correctorSql, addWhereConditions); + correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions); + log.info("correctorSql after replacing:{}", correctorSql); + correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql); + return correctorSql; + } + + private void replaceMetrics(SemanticParseInfo parseInfo, List metrics) { + List filteredMetrics = parseInfo.getMetrics().stream() + .map(o -> o.getName()).collect(Collectors.toList()); + String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); + log.info("before replaceMetrics:{}", correctorSql); + correctorSql = SqlParserAddHelper.addFieldsToSelect(correctorSql, metrics); + correctorSql = SqlParserRemoveHelper.removeSelect(correctorSql, filteredMetrics); + log.info("after replaceMetrics:{}", correctorSql); + parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); + } + @Override public EntityInfo getEntityInfo(Long queryId, Integer parseId, User user) { ChatParseDO chatParseDO = chatService.getParseInfo(queryId, parseId); @@ -520,12 +553,27 @@ public class QueryServiceImpl implements QueryService { if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { parseInfo.setDimensionFilters(queryData.getDimensionFilters()); } + if (CollectionUtils.isNotEmpty(queryData.getMetricFilters())) { + parseInfo.setMetricFilters(queryData.getMetricFilters()); + } if (Objects.nonNull(queryData.getDateInfo())) { parseInfo.setDateInfo(queryData.getDateInfo()); } return parseInfo; } + private void validFilter(Set filters) { + for (QueryFilter queryFilter : filters) { + if (Objects.isNull(queryFilter.getValue())) { + filters.remove(queryFilter); + } + if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty( + JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) { + filters.remove(queryFilter); + } + } + } + @Override public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception { QueryResultWithSchemaResp queryResultWithSchemaResp = new QueryResultWithSchemaResp(); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java index 15b1a7666..2e9432d31 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/DateFunctionHelper.java @@ -35,11 +35,7 @@ public class DateFunctionHelper { } public static String getEndDateOperator(ComparisonOperator comparisonOperator) { - String operator = comparisonOperator.getStringExpression(); - if (JsqlConstants.EQUAL.equalsIgnoreCase(operator)) { - operator = "<="; - } - return operator; + return "<="; } public static String getEndDateValue(List leftExpressions) { @@ -53,4 +49,4 @@ public class DateFunctionHelper { } -} \ No newline at end of file +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java index dee99a1df..9bd780352 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java @@ -9,6 +9,7 @@ import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.GreaterThan; @@ -97,19 +98,19 @@ public class FunctionReplaceVisitor extends ExpressionVisitorAdapter { String startDataCondExpr = columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue); - if (JsqlConstants.EQUAL.equalsIgnoreCase(endDateOperator)) { result.add(CCJSqlParserUtil.parseCondExpression(condExpr)); expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(JsqlConstants.EQUAL_CONSTANT); } - comparisonOperator.setLeftExpression(null); - comparisonOperator.setRightExpression(null); - comparisonOperator.setASTNode(null); - - comparisonOperator.setLeftExpression(expression.getLeftExpression()); - comparisonOperator.setRightExpression(expression.getRightExpression()); - comparisonOperator.setASTNode(expression.getASTNode()); - + if (startDateOperator.equals("<=") || startDateOperator.equals("<")) { + comparisonOperator.setLeftExpression(new Column("1")); + comparisonOperator.setRightExpression(new LongValue(1)); + comparisonOperator.setASTNode(null); + } else { + comparisonOperator.setLeftExpression(expression.getLeftExpression()); + comparisonOperator.setRightExpression(expression.getRightExpression()); + comparisonOperator.setASTNode(expression.getASTNode()); + } result.add(CCJSqlParserUtil.parseCondExpression(startDataCondExpr)); return result; } catch (JSQLParserException e) { @@ -119,4 +120,4 @@ public class FunctionReplaceVisitor extends ExpressionVisitorAdapter { return null; } -} \ No newline at end of file +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java index 61b336697..8592785ab 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java @@ -16,7 +16,7 @@ public class JsqlConstants { public static final String EQUAL_CONSTANT = " 1 = 1 "; public static final String IN_CONSTANT = " 1 in (1) "; - public static final String LIKE_CONSTANT = "'a' like 'a'"; + public static final String LIKE_CONSTANT = "1 like 1"; public static final String IN = "IN"; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java index debb7a9b8..5b4f4ca7e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java @@ -27,6 +27,8 @@ import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; +import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; import org.springframework.util.CollectionUtils; import java.util.List; @@ -39,6 +41,29 @@ import java.util.Set; @Slf4j public class SqlParserRemoveHelper { + public static String removeSelect(String sql, List filteredMetrics) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List selectItemList = ((PlainSelect) selectBody).getSelectItems(); + selectItemList.removeIf(o -> { + Expression expression = ((SelectExpressionItem) o).getExpression(); + if (expression instanceof Column) { + Column column = (Column) expression; + String columnName = column.getColumnName(); + if (filteredMetrics.contains(columnName)) { + return true; + } + } + return false; + }); + ((PlainSelect) selectBody).setSelectItems(selectItemList); + return selectStatement.toString(); + } + public static String removeWhereCondition(String sql, Set removeFieldNames) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); @@ -310,16 +335,28 @@ public class SqlParserRemoveHelper { return removeSingleFilter((EqualsTo) expression); } else if (expression instanceof NotEqualsTo) { return removeSingleFilter((NotEqualsTo) expression); + } else if (expression instanceof InExpression) { + InExpression inExpression = (InExpression) expression; + Expression leftExpression = inExpression.getLeftExpression(); + return distinguishNumberCondition(leftExpression, expression); + } else if (expression instanceof LikeExpression) { + LikeExpression likeExpression = (LikeExpression) expression; + Expression leftExpression = likeExpression.getLeftExpression(); + return distinguishNumberCondition(leftExpression, expression); } return expression; } private static Expression removeSingleFilter(T comparisonExpression) { Expression leftExpression = comparisonExpression.getLeftExpression(); + return distinguishNumberCondition(leftExpression, comparisonExpression); + } + + public static Expression distinguishNumberCondition(Expression leftExpression, Expression expression) { if (leftExpression instanceof LongValue) { return null; } else { - return comparisonExpression; + return expression; } } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java index bd05e6131..8240209e6 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java @@ -155,6 +155,12 @@ class SqlParserReplaceHelperTest { + "song_publis_date = '2023-08-01' AND publish_date >= '2023-08-08' " + "ORDER BY play_count DESC LIMIT 11", replaceSql); + replaceSql = "select 品牌名称 from 互联网企业 where datediff('year', 品牌成立时间, '2023-11-04') > 17 and 注册资本 = 50000000"; + replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); + replaceSql = SqlParserRemoveHelper.removeNumberCondition(replaceSql); + Assert.assertEquals( + "SELECT 品牌名称 FROM 互联网企业 WHERE 注册资本 = 50000000 AND 品牌成立时间 < '2006-11-04'", replaceSql); + replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " + "group by MONTH(数据日期) order by sum(访问次数) desc limit 1"; @@ -250,11 +256,11 @@ class SqlParserReplaceHelperTest { "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 " + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); replaceSql = SqlParserReplaceHelper.replaceFunction(replaceSql); - + replaceSql = SqlParserRemoveHelper.removeNumberCondition(replaceSql); Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE publish_date >= '2023-08-09' " - + "AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' " - + "AND publish_date <= '2023-02-09' ORDER BY play_count DESC LIMIT 11", replaceSql); + "SELECT song_name FROM 歌曲库 WHERE singer_name = '邓紫棋' " + + "AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-02-09'" + + " ORDER BY play_count DESC LIMIT 11", replaceSql); replaceSql = SqlParserReplaceHelper.replaceFields( "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice'"