mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) aggregator supports from chinese to english in s2sql (#371)
This commit is contained in:
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -22,7 +23,9 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
AggregateEnum.getAggregateEnum());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
||||
replaceAlias(semanticParseInfo);
|
||||
|
||||
updateFieldNameByLinkingValue(semanticParseInfo);
|
||||
|
||||
@@ -24,8 +24,8 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
|
||||
return;
|
||||
}
|
||||
String queryMode = semanticParseInfo.getQueryMode();
|
||||
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(
|
||||
queryMode)) {
|
||||
if (QueryManager.containsPluginQuery(queryMode)
|
||||
|| MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
|
||||
return;
|
||||
}
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
|
||||
@@ -26,8 +26,8 @@ public class EntityInfoParseResponder implements ParseResponder {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
selectedParses.forEach(parseInfo -> {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(
|
||||
queryMode)) {
|
||||
if (QueryManager.containsPluginQuery(queryMode)
|
||||
|| MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
|
||||
return;
|
||||
}
|
||||
//1. set entity info
|
||||
|
||||
@@ -65,5 +65,4 @@ public class Constants {
|
||||
public static final Long DEFAULT_FREQUENCY = 100000L;
|
||||
|
||||
public static final String TABLE_PREFIX = "t_";
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
 * Chinese aggregate keywords paired with the English SQL aggregate function
 * name that should be used in their place.
 *
 * <p>Several Chinese keywords intentionally share the same English function
 * (for example three different keywords all map to {@code max}).
 */
public enum AggregateEnum {
    MOST("最多", "max"),
    HIGHEST("最高", "max"),
    // The constant name is a misspelling of MAXIMUM; it is kept unchanged so
    // that existing references to it continue to compile.
    MAXIMUN("最大", "max"),
    LEAST("最少", "min"),
    SMALLEST("最小", "min"),
    LOWEST("最低", "min"),
    AVERAGE("平均", "avg");

    /** Chinese keyword as it appears in the incoming SQL text. */
    private final String aggregateCh;

    /** English aggregate function name that replaces the keyword. */
    private final String aggregateEN;

    AggregateEnum(String aggregateCh, String aggregateEN) {
        this.aggregateCh = aggregateCh;
        this.aggregateEN = aggregateEN;
    }

    /**
     * @return the Chinese keyword of this constant
     */
    public String getAggregateCh() {
        return aggregateCh;
    }

    /**
     * @return the English aggregate function name of this constant
     */
    public String getAggregateEN() {
        return aggregateEN;
    }

    /**
     * Builds a lookup table from every Chinese keyword to its English function name.
     *
     * <p>A new map is created on each call. The Chinese keywords are distinct,
     * so the underlying {@code toMap} collector never sees a duplicate key.
     *
     * @return map keyed by Chinese keyword, valued by English function name
     */
    public static Map<String, String> getAggregateEnum() {
        return Arrays.stream(AggregateEnum.values())
                .collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
    }
}
|
||||
@@ -1,14 +1,19 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.BinaryExpression;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Parenthesis;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
@@ -49,6 +54,21 @@ public class SqlParserRemoveHelper {
|
||||
}
|
||||
removeWhereExpression(whereExpression, removeFieldNames);
|
||||
}
|
||||
public static String removeWhereCondition(String sql) {
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
Expression where = ((PlainSelect) selectBody).getWhere();
|
||||
Expression having = ((PlainSelect) selectBody).getHaving();
|
||||
where = filteredWhereExpression(where);
|
||||
having = filteredWhereExpression(having);
|
||||
((PlainSelect) selectBody).setWhere(where);
|
||||
((PlainSelect) selectBody).setHaving(having);
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
private static void removeWhereExpression(Expression whereExpression, Set<String> removeFieldNames) {
|
||||
if (SqlParserSelectHelper.isLogicExpression(whereExpression)) {
|
||||
@@ -171,5 +191,78 @@ public class SqlParserRemoveHelper {
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
private static Expression filteredWhereExpression(Expression where) {
|
||||
if (Objects.isNull(where)) {
|
||||
return null;
|
||||
}
|
||||
if (where instanceof Parenthesis) {
|
||||
Expression expression = filteredWhereExpression(((Parenthesis) where).getExpression());
|
||||
if (expression != null) {
|
||||
try {
|
||||
Expression parseExpression = CCJSqlParserUtil.parseExpression("(" + expression + ")");
|
||||
return parseExpression;
|
||||
} catch (JSQLParserException jsqlParserException) {
|
||||
log.info("jsqlParser has an exception:{}", jsqlParserException.toString());
|
||||
}
|
||||
} else {
|
||||
return expression;
|
||||
}
|
||||
} else if (where instanceof AndExpression) {
|
||||
AndExpression andExpression = (AndExpression) where;
|
||||
return filteredNumberExpression(andExpression);
|
||||
} else if (where instanceof OrExpression) {
|
||||
OrExpression orExpression = (OrExpression) where;
|
||||
return filteredNumberExpression(orExpression);
|
||||
} else {
|
||||
return replaceComparisonOperatorFunction(where);
|
||||
}
|
||||
return where;
|
||||
}
|
||||
|
||||
private static <T extends BinaryExpression> Expression filteredNumberExpression(T binaryExpression) {
|
||||
Expression leftExpression = filteredWhereExpression(binaryExpression.getLeftExpression());
|
||||
Expression rightExpression = filteredWhereExpression(binaryExpression.getRightExpression());
|
||||
if (leftExpression != null && rightExpression != null) {
|
||||
binaryExpression.setLeftExpression(leftExpression);
|
||||
binaryExpression.setRightExpression(rightExpression);
|
||||
return binaryExpression;
|
||||
} else if (leftExpression != null && rightExpression == null) {
|
||||
return leftExpression;
|
||||
} else if (leftExpression == null && rightExpression != null) {
|
||||
return rightExpression;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static Expression replaceComparisonOperatorFunction(Expression expression) {
|
||||
if (Objects.isNull(expression)) {
|
||||
return null;
|
||||
}
|
||||
if (expression instanceof GreaterThanEquals) {
|
||||
return removeSingleFilter((GreaterThanEquals) expression);
|
||||
} else if (expression instanceof GreaterThan) {
|
||||
return removeSingleFilter((GreaterThan) expression);
|
||||
} else if (expression instanceof MinorThan) {
|
||||
return removeSingleFilter((MinorThan) expression);
|
||||
} else if (expression instanceof MinorThanEquals) {
|
||||
return removeSingleFilter((MinorThanEquals) expression);
|
||||
} else if (expression instanceof EqualsTo) {
|
||||
return removeSingleFilter((EqualsTo) expression);
|
||||
} else if (expression instanceof NotEqualsTo) {
|
||||
return removeSingleFilter((NotEqualsTo) expression);
|
||||
}
|
||||
return expression;
|
||||
}
|
||||
|
||||
private static <T extends ComparisonOperator> Expression removeSingleFilter(T comparisonExpression) {
|
||||
Expression leftExpression = comparisonExpression.getLeftExpression();
|
||||
if (leftExpression instanceof LongValue) {
|
||||
return null;
|
||||
} else {
|
||||
return comparisonExpression;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,22 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||
import net.sf.jsqlparser.statement.select.OrderByElement;
|
||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||
import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.select.SubSelect;
|
||||
import net.sf.jsqlparser.statement.select.Join;
|
||||
import net.sf.jsqlparser.statement.select.SelectBody;
|
||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||
@@ -110,6 +121,17 @@ public class SqlParserReplaceHelper {
|
||||
if (Objects.nonNull(having)) {
|
||||
having.accept(visitor);
|
||||
}
|
||||
List<Join> joins = plainSelect.getJoins();
|
||||
if (!CollectionUtils.isEmpty(joins)) {
|
||||
for (Join join : joins) {
|
||||
join.getOnExpression().accept(visitor);
|
||||
SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
|
||||
List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) subSelectBody);
|
||||
for (PlainSelect subPlainSelect : subPlainSelects) {
|
||||
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static String replaceFunction(String sql, Map<String, String> functionMap) {
|
||||
@@ -143,6 +165,12 @@ public class SqlParserReplaceHelper {
|
||||
for (SelectItem selectItem : plainSelect.getSelectItems()) {
|
||||
selectItem.accept(visitor);
|
||||
}
|
||||
Expression having = plainSelect.getHaving();
|
||||
if (Objects.nonNull(having)) {
|
||||
replaceHavingFunction(functionMap, having);
|
||||
}
|
||||
List<OrderByElement> orderByElementList = plainSelect.getOrderByElements();
|
||||
replaceOrderByFunction(functionMap, orderByElementList);
|
||||
}
|
||||
|
||||
public static String replaceFunction(String sql) {
|
||||
@@ -172,6 +200,67 @@ public class SqlParserReplaceHelper {
|
||||
addWaitingExpression(plainSelect, where, waitingForAdds);
|
||||
}
|
||||
|
||||
private static void replaceHavingFunction(Map<String, String> functionMap, Expression having) {
|
||||
if (Objects.nonNull(having)) {
|
||||
if (having instanceof AndExpression) {
|
||||
AndExpression andExpression = (AndExpression) having;
|
||||
replaceHavingFunction(functionMap, andExpression.getLeftExpression());
|
||||
replaceHavingFunction(functionMap, andExpression.getRightExpression());
|
||||
} else if (having instanceof OrExpression) {
|
||||
OrExpression orExpression = (OrExpression) having;
|
||||
replaceHavingFunction(functionMap, orExpression.getLeftExpression());
|
||||
replaceHavingFunction(functionMap, orExpression.getRightExpression());
|
||||
} else {
|
||||
replaceComparisonOperatorFunction(functionMap, having);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void replaceComparisonOperatorFunction(Map<String, String> functionMap, Expression expression) {
|
||||
if (Objects.isNull(expression)) {
|
||||
return;
|
||||
}
|
||||
if (expression instanceof GreaterThanEquals) {
|
||||
replaceFilterFunction(functionMap, (GreaterThanEquals) expression);
|
||||
} else if (expression instanceof GreaterThan) {
|
||||
replaceFilterFunction(functionMap, (GreaterThan) expression);
|
||||
} else if (expression instanceof MinorThan) {
|
||||
replaceFilterFunction(functionMap, (MinorThan) expression);
|
||||
} else if (expression instanceof MinorThanEquals) {
|
||||
replaceFilterFunction(functionMap, (MinorThanEquals) expression);
|
||||
} else if (expression instanceof EqualsTo) {
|
||||
replaceFilterFunction(functionMap, (EqualsTo) expression);
|
||||
} else if (expression instanceof NotEqualsTo) {
|
||||
replaceFilterFunction(functionMap, (NotEqualsTo) expression);
|
||||
}
|
||||
}
|
||||
|
||||
private static void replaceOrderByFunction(Map<String, String> functionMap,
|
||||
List<OrderByElement> orderByElementList) {
|
||||
if (Objects.isNull(orderByElementList)) {
|
||||
return;
|
||||
}
|
||||
for (OrderByElement orderByElement : orderByElementList) {
|
||||
if (orderByElement.getExpression() instanceof Function) {
|
||||
Function function = (Function) orderByElement.getExpression();
|
||||
if (functionMap.containsKey(function.getName())) {
|
||||
function.setName(functionMap.get(function.getName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static <T extends ComparisonOperator> void replaceFilterFunction(
|
||||
Map<String, String> functionMap, T comparisonExpression) {
|
||||
Expression expression = comparisonExpression.getLeftExpression();
|
||||
if (expression instanceof Function) {
|
||||
Function function = (Function) expression;
|
||||
if (functionMap.containsKey(function.getName())) {
|
||||
function.setName(functionMap.get(function.getName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void addWaitingExpression(PlainSelect plainSelect, Expression where,
|
||||
List<Expression> waitingForAdds) {
|
||||
if (CollectionUtils.isEmpty(waitingForAdds)) {
|
||||
@@ -204,6 +293,17 @@ public class SqlParserReplaceHelper {
|
||||
plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
|
||||
}
|
||||
});
|
||||
List<Join> joins = painSelect.getJoins();
|
||||
if (!CollectionUtils.isEmpty(joins)) {
|
||||
for (Join join : joins) {
|
||||
SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
|
||||
List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects(
|
||||
(PlainSelect) subSelectBody);
|
||||
for (PlainSelect subPlainSelect : subPlainSelects) {
|
||||
subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
@@ -10,6 +10,45 @@ import org.junit.jupiter.api.Test;
|
||||
*/
|
||||
class SqlParserRemoveHelperTest {
|
||||
|
||||
@Test
|
||||
void removeWhereHavingCondition() {
|
||||
String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
|
||||
+ "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1";
|
||||
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
|
||||
System.out.println(sql);
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE sum(粉丝数) > 20000 AND sum(播放量) > 20000 HAVING sum(播放量) > 20000",
|
||||
sql);
|
||||
sql = "SELECT 歌曲,sum(播放量) FROM 歌曲库\n"
|
||||
+ "WHERE (歌手名 = '张三' AND 2 > 1) AND 数据日期 = '2023-11-07'\n"
|
||||
+ "GROUP BY 歌曲名 HAVING sum(播放量) > 100000";
|
||||
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
|
||||
System.out.println(sql);
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲, sum(播放量) FROM 歌曲库 WHERE (歌手名 = '张三') "
|
||||
+ "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING sum(播放量) > 100000",
|
||||
sql);
|
||||
sql = "SELECT 歌曲名,sum(播放量) FROM 歌曲库 WHERE (1 = 1 AND 1 = 1 AND 2 > 1 )"
|
||||
+ "AND 1 = 1 AND 歌曲类型 IN ('类型一', '类型二') AND 歌手名 IN ('林俊杰', '周杰伦')"
|
||||
+ "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING 2 > 1 AND SUM(播放量) >= 1000";
|
||||
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
|
||||
System.out.println(sql);
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名, sum(播放量) FROM 歌曲库 WHERE 歌曲类型 IN ('类型一', '类型二') "
|
||||
+ "AND 歌手名 IN ('林俊杰', '周杰伦') AND 数据日期 = '2023-11-07' "
|
||||
+ "GROUP BY 歌曲名 HAVING SUM(播放量) >= 1000",
|
||||
sql);
|
||||
|
||||
sql = "SELECT 品牌名称,法人 FROM 互联网企业 WHERE (2 > 1 AND 1 = 1) AND 数据日期 = '2023-10-31'"
|
||||
+ "GROUP BY 品牌名称, 法人 HAVING 2 > 1 AND sum(注册资本) > 100000000 AND sum(营收占比) = 0.5 and 1 = 1";
|
||||
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
|
||||
System.out.println(sql);
|
||||
Assert.assertEquals(
|
||||
"SELECT 品牌名称, 法人 FROM 互联网企业 WHERE 数据日期 = '2023-10-31' GROUP BY "
|
||||
+ "品牌名称, 法人 HAVING sum(注册资本) > 100000000 AND sum(营收占比) = 0.5",
|
||||
sql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void removeHavingCondition() {
|
||||
String sql = "select 歌曲名 from 歌曲库 where 歌手名 = '周杰伦' HAVING sum(播放量) > 20000";
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -343,7 +344,19 @@ class SqlParserReplaceHelperTest {
|
||||
@Test
|
||||
void replaceFunctionName() {
|
||||
|
||||
String sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where"
|
||||
String sql = "select 公司名称,平均(注册资本),总部地点 from 互联网企业 where\n"
|
||||
+ "年营业额 >= 28800000000 and 最大(注册资本)>10000 \n"
|
||||
+ " group by 公司名称 having 平均(注册资本)>10000 order by \n"
|
||||
+ "平均(注册资本) desc limit 5";
|
||||
Map<String, String> map = new HashMap<>();
|
||||
map.put("平均", "avg");
|
||||
map.put("最大", "max");
|
||||
sql = SqlParserReplaceHelper.replaceFunction(sql, map);
|
||||
System.out.println(sql);
|
||||
Assert.assertEquals("SELECT 公司名称, avg(注册资本), 总部地点 FROM 互联网企业 WHERE 年营业额 >= 28800000000 AND "
|
||||
+ "max(注册资本) > 10000 GROUP BY 公司名称 HAVING avg(注册资本) > 10000 ORDER BY avg(注册资本) DESC LIMIT 5", sql);
|
||||
|
||||
sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where"
|
||||
+ " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)";
|
||||
Map<String, String> functionMap = new HashMap<>();
|
||||
functionMap.put("MONTH".toLowerCase(), "toMonth");
|
||||
|
||||
Reference in New Issue
Block a user