From cdb84716b78480a962f075c83d6f437e31169b5e Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Mon, 13 Nov 2023 14:51:23 +0800 Subject: [PATCH] (improvement)(chat) aggregator supports from chinese to english in s2sql (#371) --- .../chat/corrector/SchemaCorrector.java | 5 +- .../execute/EntityInfoExecuteResponder.java | 4 +- .../parse/EntityInfoParseResponder.java | 4 +- .../supersonic/common/pojo/Constants.java | 1 - .../common/util/jsqlparser/AggregateEnum.java | 36 +++++++ .../jsqlparser/SqlParserRemoveHelper.java | 93 ++++++++++++++++ .../jsqlparser/SqlParserReplaceHelper.java | 100 ++++++++++++++++++ .../jsqlparser/SqlParserRemoveHelperTest.java | 39 +++++++ .../SqlParserReplaceHelperTest.java | 23 +++- 9 files changed, 294 insertions(+), 11 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateEnum.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java index 5c54da6fb..2362bf0fe 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java @@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import java.util.List; import java.util.Map; @@ -22,7 +23,9 @@ public class SchemaCorrector extends BaseSemanticCorrector { @Override public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { - + String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), + AggregateEnum.getAggregateEnum()); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql); replaceAlias(semanticParseInfo); updateFieldNameByLinkingValue(semanticParseInfo); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java index 928ebb7e5..05418c896 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java @@ -24,8 +24,8 @@ public class EntityInfoExecuteResponder implements ExecuteResponder { return; } String queryMode = semanticParseInfo.getQueryMode(); - if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase( - queryMode)) { + if (QueryManager.containsPluginQuery(queryMode) + || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) { return; } SemanticService semanticService = ContextUtils.getBean(SemanticService.class); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java index 334ae51bb..b563884e4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java @@ -26,8 +26,8 @@ public class EntityInfoParseResponder implements ParseResponder { QueryReq queryReq = queryContext.getRequest(); selectedParses.forEach(parseInfo -> { String queryMode = parseInfo.getQueryMode(); - if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase( - queryMode)) { + if (QueryManager.containsPluginQuery(queryMode) + || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) { return; } //1. set entity info diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Constants.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Constants.java index cd08b6885..31185ce1c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Constants.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Constants.java @@ -65,5 +65,4 @@ public class Constants { public static final Long DEFAULT_FREQUENCY = 100000L; public static final String TABLE_PREFIX = "t_"; - } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateEnum.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateEnum.java new file mode 100644 index 000000000..04fc82566 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/AggregateEnum.java @@ -0,0 +1,36 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; + +public enum AggregateEnum { + MOST("最多", "max"), + HIGHEST("最高", "max"), + MAXIMUN("最大", "max"), + LEAST("最少", "min"), + SMALLEST("最小", "min"), + LOWEST("最低", "min"), + AVERAGE("平均", "avg"); + private String aggregateCh; + private String aggregateEN; + + AggregateEnum(String aggregateCh, String aggregateEN) { + this.aggregateCh = aggregateCh; + this.aggregateEN = aggregateEN; + } + + public String getAggregateCh() { + return aggregateCh; + } + + public String getAggregateEN() { + return aggregateEN; + } + + public static Map getAggregateEnum() { + Map aggregateMap = Arrays.stream(AggregateEnum.values()) + .collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN)); + return aggregateMap; + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java index 8707c6d13..13e590d42 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java @@ -1,14 +1,19 @@ package com.tencent.supersonic.common.util.jsqlparser; import java.util.List; +import java.util.Objects; import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Parenthesis; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo; import net.sf.jsqlparser.expression.operators.relational.GreaterThan; import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.InExpression; @@ -49,6 +54,21 @@ public class SqlParserRemoveHelper { } removeWhereExpression(whereExpression, removeFieldNames); } + public static String removeWhereCondition(String sql) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + Expression where = ((PlainSelect) selectBody).getWhere(); + Expression having = ((PlainSelect) selectBody).getHaving(); + where = filteredWhereExpression(where); + having = filteredWhereExpression(having); + ((PlainSelect) selectBody).setWhere(where); + ((PlainSelect) selectBody).setHaving(having); + return selectStatement.toString(); + } private static void removeWhereExpression(Expression whereExpression, Set removeFieldNames) { if (SqlParserSelectHelper.isLogicExpression(whereExpression)) { @@ -171,5 +191,78 @@ public class SqlParserRemoveHelper { return selectStatement.toString(); } + private static Expression filteredWhereExpression(Expression where) { + if (Objects.isNull(where)) { + return null; + } + if (where instanceof Parenthesis) { + Expression expression = filteredWhereExpression(((Parenthesis) where).getExpression()); + if (expression != null) { + try { + Expression parseExpression = CCJSqlParserUtil.parseExpression("(" + expression + ")"); + return parseExpression; + } catch (JSQLParserException jsqlParserException) { + log.info("jsqlParser has an exception:{}", jsqlParserException.toString()); + } + } else { + return expression; + } + } else if (where instanceof AndExpression) { + AndExpression andExpression = (AndExpression) where; + return filteredNumberExpression(andExpression); + } else if (where instanceof OrExpression) { + OrExpression orExpression = (OrExpression) where; + return filteredNumberExpression(orExpression); + } else { + return replaceComparisonOperatorFunction(where); + } + return where; + } + + private static Expression filteredNumberExpression(T binaryExpression) { + Expression leftExpression = filteredWhereExpression(binaryExpression.getLeftExpression()); + Expression rightExpression = filteredWhereExpression(binaryExpression.getRightExpression()); + if (leftExpression != null && rightExpression != null) { + binaryExpression.setLeftExpression(leftExpression); + binaryExpression.setRightExpression(rightExpression); + return binaryExpression; + } else if (leftExpression != null && rightExpression == null) { + return leftExpression; + } else if (leftExpression == null && rightExpression != null) { + return rightExpression; + } else { + return null; + } + } + + private static Expression replaceComparisonOperatorFunction(Expression expression) { + if (Objects.isNull(expression)) { + return null; + } + if (expression instanceof GreaterThanEquals) { + return removeSingleFilter((GreaterThanEquals) expression); + } else if (expression instanceof GreaterThan) { + return removeSingleFilter((GreaterThan) expression); + } else if (expression instanceof MinorThan) { + return removeSingleFilter((MinorThan) expression); + } else if (expression instanceof MinorThanEquals) { + return removeSingleFilter((MinorThanEquals) expression); + } else if (expression instanceof EqualsTo) { + return removeSingleFilter((EqualsTo) expression); + } else if (expression instanceof NotEqualsTo) { + return removeSingleFilter((NotEqualsTo) expression); + } + return expression; + } + + private static Expression removeSingleFilter(T comparisonExpression) { + Expression leftExpression = comparisonExpression.getLeftExpression(); + if (leftExpression instanceof LongValue) { + return null; + } else { + return comparisonExpression; + } + } + } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java index cc065894b..49501b2b7 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java @@ -6,11 +6,22 @@ import java.util.Objects; import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.conditional.OrExpression; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SubSelect; +import net.sf.jsqlparser.statement.select.Join; import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; @@ -110,6 +121,17 @@ public class SqlParserReplaceHelper { if (Objects.nonNull(having)) { having.accept(visitor); } + List joins = plainSelect.getJoins(); + if (!CollectionUtils.isEmpty(joins)) { + for (Join join : joins) { + join.getOnExpression().accept(visitor); + SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody(); + List subPlainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) subSelectBody); + for (PlainSelect subPlainSelect : subPlainSelects) { + replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect); + } + } + } } public static String replaceFunction(String sql, Map functionMap) { @@ -143,6 +165,12 @@ public class SqlParserReplaceHelper { for (SelectItem selectItem : plainSelect.getSelectItems()) { selectItem.accept(visitor); } + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + replaceHavingFunction(functionMap, having); + } + List orderByElementList = plainSelect.getOrderByElements(); + replaceOrderByFunction(functionMap, orderByElementList); } public static String replaceFunction(String sql) { @@ -172,6 +200,67 @@ public class SqlParserReplaceHelper { addWaitingExpression(plainSelect, where, waitingForAdds); } + private static void replaceHavingFunction(Map functionMap, Expression having) { + if (Objects.nonNull(having)) { + if (having instanceof AndExpression) { + AndExpression andExpression = (AndExpression) having; + replaceHavingFunction(functionMap, andExpression.getLeftExpression()); + replaceHavingFunction(functionMap, andExpression.getRightExpression()); + } else if (having instanceof OrExpression) { + OrExpression orExpression = (OrExpression) having; + replaceHavingFunction(functionMap, orExpression.getLeftExpression()); + replaceHavingFunction(functionMap, orExpression.getRightExpression()); + } else { + replaceComparisonOperatorFunction(functionMap, having); + } + } + } + + private static void replaceComparisonOperatorFunction(Map functionMap, Expression expression) { + if (Objects.isNull(expression)) { + return; + } + if (expression instanceof GreaterThanEquals) { + replaceFilterFunction(functionMap, (GreaterThanEquals) expression); + } else if (expression instanceof GreaterThan) { + replaceFilterFunction(functionMap, (GreaterThan) expression); + } else if (expression instanceof MinorThan) { + replaceFilterFunction(functionMap, (MinorThan) expression); + } else if (expression instanceof MinorThanEquals) { + replaceFilterFunction(functionMap, (MinorThanEquals) expression); + } else if (expression instanceof EqualsTo) { + replaceFilterFunction(functionMap, (EqualsTo) expression); + } else if (expression instanceof NotEqualsTo) { + replaceFilterFunction(functionMap, (NotEqualsTo) expression); + } + } + + private static void replaceOrderByFunction(Map functionMap, + List orderByElementList) { + if (Objects.isNull(orderByElementList)) { + return; + } + for (OrderByElement orderByElement : orderByElementList) { + if (orderByElement.getExpression() instanceof Function) { + Function function = (Function) orderByElement.getExpression(); + if (functionMap.containsKey(function.getName())) { + function.setName(functionMap.get(function.getName())); + } + } + } + } + + private static void replaceFilterFunction( + Map functionMap, T comparisonExpression) { + Expression expression = comparisonExpression.getLeftExpression(); + if (expression instanceof Function) { + Function function = (Function) expression; + if (functionMap.containsKey(function.getName())) { + function.setName(functionMap.get(function.getName())); + } + } + } + private static void addWaitingExpression(PlainSelect plainSelect, Expression where, List waitingForAdds) { if (CollectionUtils.isEmpty(waitingForAdds)) { @@ -204,6 +293,17 @@ public class SqlParserReplaceHelper { plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName)); } }); + List joins = painSelect.getJoins(); + if (!CollectionUtils.isEmpty(joins)) { + for (Join join : joins) { + SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody(); + List subPlainSelects = SqlParserSelectHelper.getPlainSelects( + (PlainSelect) subSelectBody); + for (PlainSelect subPlainSelect : subPlainSelects) { + subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName)); + } + } + } } return selectStatement.toString(); } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java index 7c4365e09..d1b71cae6 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelperTest.java @@ -10,6 +10,45 @@ import org.junit.jupiter.api.Test; */ class SqlParserRemoveHelperTest { + @Test + void removeWhereHavingCondition() { + String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " + + "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1"; + sql = SqlParserRemoveHelper.removeWhereCondition(sql); + System.out.println(sql); + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE sum(粉丝数) > 20000 AND sum(播放量) > 20000 HAVING sum(播放量) > 20000", + sql); + sql = "SELECT 歌曲,sum(播放量) FROM 歌曲库\n" + + "WHERE (歌手名 = '张三' AND 2 > 1) AND 数据日期 = '2023-11-07'\n" + + "GROUP BY 歌曲名 HAVING sum(播放量) > 100000"; + sql = SqlParserRemoveHelper.removeWhereCondition(sql); + System.out.println(sql); + Assert.assertEquals( + "SELECT 歌曲, sum(播放量) FROM 歌曲库 WHERE (歌手名 = '张三') " + + "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING sum(播放量) > 100000", + sql); + sql = "SELECT 歌曲名,sum(播放量) FROM 歌曲库 WHERE (1 = 1 AND 1 = 1 AND 2 > 1 )" + + "AND 1 = 1 AND 歌曲类型 IN ('类型一', '类型二') AND 歌手名 IN ('林俊杰', '周杰伦')" + + "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING 2 > 1 AND SUM(播放量) >= 1000"; + sql = SqlParserRemoveHelper.removeWhereCondition(sql); + System.out.println(sql); + Assert.assertEquals( + "SELECT 歌曲名, sum(播放量) FROM 歌曲库 WHERE 歌曲类型 IN ('类型一', '类型二') " + + "AND 歌手名 IN ('林俊杰', '周杰伦') AND 数据日期 = '2023-11-07' " + + "GROUP BY 歌曲名 HAVING SUM(播放量) >= 1000", + sql); + + sql = "SELECT 品牌名称,法人 FROM 互联网企业 WHERE (2 > 1 AND 1 = 1) AND 数据日期 = '2023-10-31'" + + "GROUP BY 品牌名称, 法人 HAVING 2 > 1 AND sum(注册资本) > 100000000 AND sum(营收占比) = 0.5 and 1 = 1"; + sql = SqlParserRemoveHelper.removeWhereCondition(sql); + System.out.println(sql); + Assert.assertEquals( + "SELECT 品牌名称, 法人 FROM 互联网企业 WHERE 数据日期 = '2023-10-31' GROUP BY " + + "品牌名称, 法人 HAVING sum(注册资本) > 100000000 AND sum(营收占比) = 0.5", + sql); + } + @Test void removeHavingCondition() { String sql = "select 歌曲名 from 歌曲库 where 歌手名 = '周杰伦' HAVING sum(播放量) > 20000"; diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java index 6fe02a9e3..bd05e6131 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java @@ -1,10 +1,11 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; import java.util.Set; +import java.util.HashSet; +import java.util.Collections; +import java.util.Map; +import java.util.HashMap; + import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -343,7 +344,19 @@ class SqlParserReplaceHelperTest { @Test void replaceFunctionName() { - String sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + String sql = "select 公司名称,平均(注册资本),总部地点 from 互联网企业 where\n" + + "年营业额 >= 28800000000 and 最大(注册资本)>10000 \n" + + " group by 公司名称 having 平均(注册资本)>10000 order by \n" + + "平均(注册资本) desc limit 5"; + Map map = new HashMap<>(); + map.put("平均", "avg"); + map.put("最大", "max"); + sql = SqlParserReplaceHelper.replaceFunction(sql, map); + System.out.println(sql); + Assert.assertEquals("SELECT 公司名称, avg(注册资本), 总部地点 FROM 互联网企业 WHERE 年营业额 >= 28800000000 AND " + + "max(注册资本) > 10000 GROUP BY 公司名称 HAVING avg(注册资本) > 10000 ORDER BY avg(注册资本) DESC LIMIT 5", sql); + + sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; Map functionMap = new HashMap<>(); functionMap.put("MONTH".toLowerCase(), "toMonth");