(improvement)(chat) aggregator supports from chinese to english in s2sql (#371)

This commit is contained in:
mainmain
2023-11-13 14:51:23 +08:00
committed by GitHub
parent 731238de08
commit cdb84716b7
9 changed files with 294 additions and 11 deletions

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -22,7 +23,9 @@ public class SchemaCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
AggregateEnum.getAggregateEnum());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
replaceAlias(semanticParseInfo); replaceAlias(semanticParseInfo);
updateFieldNameByLinkingValue(semanticParseInfo); updateFieldNameByLinkingValue(semanticParseInfo);

View File

@@ -24,8 +24,8 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
return; return;
} }
String queryMode = semanticParseInfo.getQueryMode(); String queryMode = semanticParseInfo.getQueryMode();
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase( if (QueryManager.containsPluginQuery(queryMode)
queryMode)) { || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
return; return;
} }
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);

View File

@@ -26,8 +26,8 @@ public class EntityInfoParseResponder implements ParseResponder {
QueryReq queryReq = queryContext.getRequest(); QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> { selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode(); String queryMode = parseInfo.getQueryMode();
if (QueryManager.containsPluginQuery(queryMode) || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase( if (QueryManager.containsPluginQuery(queryMode)
queryMode)) { || MetricInterpretQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
return; return;
} }
//1. set entity info //1. set entity info

View File

@@ -65,5 +65,4 @@ public class Constants {
public static final Long DEFAULT_FREQUENCY = 100000L; public static final Long DEFAULT_FREQUENCY = 100000L;
public static final String TABLE_PREFIX = "t_"; public static final String TABLE_PREFIX = "t_";
} }

View File

@@ -0,0 +1,36 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
public enum AggregateEnum {
MOST("最多", "max"),
HIGHEST("最高", "max"),
MAXIMUN("最大", "max"),
LEAST("最少", "min"),
SMALLEST("最小", "min"),
LOWEST("最低", "min"),
AVERAGE("平均", "avg");
private String aggregateCh;
private String aggregateEN;
AggregateEnum(String aggregateCh, String aggregateEN) {
this.aggregateCh = aggregateCh;
this.aggregateEN = aggregateEN;
}
public String getAggregateCh() {
return aggregateCh;
}
public String getAggregateEN() {
return aggregateEN;
}
public static Map<String, String> getAggregateEnum() {
Map<String, String> aggregateMap = Arrays.stream(AggregateEnum.values())
.collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
return aggregateMap;
}
}

View File

@@ -1,14 +1,19 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan; import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression; import net.sf.jsqlparser.expression.operators.relational.InExpression;
@@ -49,6 +54,21 @@ public class SqlParserRemoveHelper {
} }
removeWhereExpression(whereExpression, removeFieldNames); removeWhereExpression(whereExpression, removeFieldNames);
} }
public static String removeWhereCondition(String sql) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
Expression where = ((PlainSelect) selectBody).getWhere();
Expression having = ((PlainSelect) selectBody).getHaving();
where = filteredWhereExpression(where);
having = filteredWhereExpression(having);
((PlainSelect) selectBody).setWhere(where);
((PlainSelect) selectBody).setHaving(having);
return selectStatement.toString();
}
private static void removeWhereExpression(Expression whereExpression, Set<String> removeFieldNames) { private static void removeWhereExpression(Expression whereExpression, Set<String> removeFieldNames) {
if (SqlParserSelectHelper.isLogicExpression(whereExpression)) { if (SqlParserSelectHelper.isLogicExpression(whereExpression)) {
@@ -171,5 +191,78 @@ public class SqlParserRemoveHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
private static Expression filteredWhereExpression(Expression where) {
if (Objects.isNull(where)) {
return null;
}
if (where instanceof Parenthesis) {
Expression expression = filteredWhereExpression(((Parenthesis) where).getExpression());
if (expression != null) {
try {
Expression parseExpression = CCJSqlParserUtil.parseExpression("(" + expression + ")");
return parseExpression;
} catch (JSQLParserException jsqlParserException) {
log.info("jsqlParser has an exception:{}", jsqlParserException.toString());
}
} else {
return expression;
}
} else if (where instanceof AndExpression) {
AndExpression andExpression = (AndExpression) where;
return filteredNumberExpression(andExpression);
} else if (where instanceof OrExpression) {
OrExpression orExpression = (OrExpression) where;
return filteredNumberExpression(orExpression);
} else {
return replaceComparisonOperatorFunction(where);
}
return where;
}
private static <T extends BinaryExpression> Expression filteredNumberExpression(T binaryExpression) {
Expression leftExpression = filteredWhereExpression(binaryExpression.getLeftExpression());
Expression rightExpression = filteredWhereExpression(binaryExpression.getRightExpression());
if (leftExpression != null && rightExpression != null) {
binaryExpression.setLeftExpression(leftExpression);
binaryExpression.setRightExpression(rightExpression);
return binaryExpression;
} else if (leftExpression != null && rightExpression == null) {
return leftExpression;
} else if (leftExpression == null && rightExpression != null) {
return rightExpression;
} else {
return null;
}
}
private static Expression replaceComparisonOperatorFunction(Expression expression) {
if (Objects.isNull(expression)) {
return null;
}
if (expression instanceof GreaterThanEquals) {
return removeSingleFilter((GreaterThanEquals) expression);
} else if (expression instanceof GreaterThan) {
return removeSingleFilter((GreaterThan) expression);
} else if (expression instanceof MinorThan) {
return removeSingleFilter((MinorThan) expression);
} else if (expression instanceof MinorThanEquals) {
return removeSingleFilter((MinorThanEquals) expression);
} else if (expression instanceof EqualsTo) {
return removeSingleFilter((EqualsTo) expression);
} else if (expression instanceof NotEqualsTo) {
return removeSingleFilter((NotEqualsTo) expression);
}
return expression;
}
private static <T extends ComparisonOperator> Expression removeSingleFilter(T comparisonExpression) {
Expression leftExpression = comparisonExpression.getLeftExpression();
if (leftExpression instanceof LongValue) {
return null;
} else {
return comparisonExpression;
}
}
} }

View File

@@ -6,11 +6,22 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
@@ -110,6 +121,17 @@ public class SqlParserReplaceHelper {
if (Objects.nonNull(having)) { if (Objects.nonNull(having)) {
having.accept(visitor); having.accept(visitor);
} }
List<Join> joins = plainSelect.getJoins();
if (!CollectionUtils.isEmpty(joins)) {
for (Join join : joins) {
join.getOnExpression().accept(visitor);
SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects((PlainSelect) subSelectBody);
for (PlainSelect subPlainSelect : subPlainSelects) {
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect);
}
}
}
} }
public static String replaceFunction(String sql, Map<String, String> functionMap) { public static String replaceFunction(String sql, Map<String, String> functionMap) {
@@ -143,6 +165,12 @@ public class SqlParserReplaceHelper {
for (SelectItem selectItem : plainSelect.getSelectItems()) { for (SelectItem selectItem : plainSelect.getSelectItems()) {
selectItem.accept(visitor); selectItem.accept(visitor);
} }
Expression having = plainSelect.getHaving();
if (Objects.nonNull(having)) {
replaceHavingFunction(functionMap, having);
}
List<OrderByElement> orderByElementList = plainSelect.getOrderByElements();
replaceOrderByFunction(functionMap, orderByElementList);
} }
public static String replaceFunction(String sql) { public static String replaceFunction(String sql) {
@@ -172,6 +200,67 @@ public class SqlParserReplaceHelper {
addWaitingExpression(plainSelect, where, waitingForAdds); addWaitingExpression(plainSelect, where, waitingForAdds);
} }
private static void replaceHavingFunction(Map<String, String> functionMap, Expression having) {
if (Objects.nonNull(having)) {
if (having instanceof AndExpression) {
AndExpression andExpression = (AndExpression) having;
replaceHavingFunction(functionMap, andExpression.getLeftExpression());
replaceHavingFunction(functionMap, andExpression.getRightExpression());
} else if (having instanceof OrExpression) {
OrExpression orExpression = (OrExpression) having;
replaceHavingFunction(functionMap, orExpression.getLeftExpression());
replaceHavingFunction(functionMap, orExpression.getRightExpression());
} else {
replaceComparisonOperatorFunction(functionMap, having);
}
}
}
private static void replaceComparisonOperatorFunction(Map<String, String> functionMap, Expression expression) {
if (Objects.isNull(expression)) {
return;
}
if (expression instanceof GreaterThanEquals) {
replaceFilterFunction(functionMap, (GreaterThanEquals) expression);
} else if (expression instanceof GreaterThan) {
replaceFilterFunction(functionMap, (GreaterThan) expression);
} else if (expression instanceof MinorThan) {
replaceFilterFunction(functionMap, (MinorThan) expression);
} else if (expression instanceof MinorThanEquals) {
replaceFilterFunction(functionMap, (MinorThanEquals) expression);
} else if (expression instanceof EqualsTo) {
replaceFilterFunction(functionMap, (EqualsTo) expression);
} else if (expression instanceof NotEqualsTo) {
replaceFilterFunction(functionMap, (NotEqualsTo) expression);
}
}
private static void replaceOrderByFunction(Map<String, String> functionMap,
List<OrderByElement> orderByElementList) {
if (Objects.isNull(orderByElementList)) {
return;
}
for (OrderByElement orderByElement : orderByElementList) {
if (orderByElement.getExpression() instanceof Function) {
Function function = (Function) orderByElement.getExpression();
if (functionMap.containsKey(function.getName())) {
function.setName(functionMap.get(function.getName()));
}
}
}
}
private static <T extends ComparisonOperator> void replaceFilterFunction(
Map<String, String> functionMap, T comparisonExpression) {
Expression expression = comparisonExpression.getLeftExpression();
if (expression instanceof Function) {
Function function = (Function) expression;
if (functionMap.containsKey(function.getName())) {
function.setName(functionMap.get(function.getName()));
}
}
}
private static void addWaitingExpression(PlainSelect plainSelect, Expression where, private static void addWaitingExpression(PlainSelect plainSelect, Expression where,
List<Expression> waitingForAdds) { List<Expression> waitingForAdds) {
if (CollectionUtils.isEmpty(waitingForAdds)) { if (CollectionUtils.isEmpty(waitingForAdds)) {
@@ -204,6 +293,17 @@ public class SqlParserReplaceHelper {
plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName)); plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
} }
}); });
List<Join> joins = painSelect.getJoins();
if (!CollectionUtils.isEmpty(joins)) {
for (Join join : joins) {
SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects(
(PlainSelect) subSelectBody);
for (PlainSelect subPlainSelect : subPlainSelects) {
subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
}
}
}
} }
return selectStatement.toString(); return selectStatement.toString();
} }

View File

@@ -10,6 +10,45 @@ import org.junit.jupiter.api.Test;
*/ */
class SqlParserRemoveHelperTest { class SqlParserRemoveHelperTest {
@Test
void removeWhereHavingCondition() {
String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
+ "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1";
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
System.out.println(sql);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE sum(粉丝数) > 20000 AND sum(播放量) > 20000 HAVING sum(播放量) > 20000",
sql);
sql = "SELECT 歌曲,sum(播放量) FROM 歌曲库\n"
+ "WHERE (歌手名 = '张三' AND 2 > 1) AND 数据日期 = '2023-11-07'\n"
+ "GROUP BY 歌曲名 HAVING sum(播放量) > 100000";
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
System.out.println(sql);
Assert.assertEquals(
"SELECT 歌曲, sum(播放量) FROM 歌曲库 WHERE (歌手名 = '张三') "
+ "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING sum(播放量) > 100000",
sql);
sql = "SELECT 歌曲名,sum(播放量) FROM 歌曲库 WHERE (1 = 1 AND 1 = 1 AND 2 > 1 )"
+ "AND 1 = 1 AND 歌曲类型 IN ('类型一', '类型二') AND 歌手名 IN ('林俊杰', '周杰伦')"
+ "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING 2 > 1 AND SUM(播放量) >= 1000";
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
System.out.println(sql);
Assert.assertEquals(
"SELECT 歌曲名, sum(播放量) FROM 歌曲库 WHERE 歌曲类型 IN ('类型一', '类型二') "
+ "AND 歌手名 IN ('林俊杰', '周杰伦') AND 数据日期 = '2023-11-07' "
+ "GROUP BY 歌曲名 HAVING SUM(播放量) >= 1000",
sql);
sql = "SELECT 品牌名称,法人 FROM 互联网企业 WHERE (2 > 1 AND 1 = 1) AND 数据日期 = '2023-10-31'"
+ "GROUP BY 品牌名称, 法人 HAVING 2 > 1 AND sum(注册资本) > 100000000 AND sum(营收占比) = 0.5 and 1 = 1";
sql = SqlParserRemoveHelper.removeWhereCondition(sql);
System.out.println(sql);
Assert.assertEquals(
"SELECT 品牌名称, 法人 FROM 互联网企业 WHERE 数据日期 = '2023-10-31' GROUP BY "
+ "品牌名称, 法人 HAVING sum(注册资本) > 100000000 AND sum(营收占比) = 0.5",
sql);
}
@Test @Test
void removeHavingCondition() { void removeHavingCondition() {
String sql = "select 歌曲名 from 歌曲库 where 歌手名 = '周杰伦' HAVING sum(播放量) > 20000"; String sql = "select 歌曲名 from 歌曲库 where 歌手名 = '周杰伦' HAVING sum(播放量) > 20000";

View File

@@ -1,10 +1,11 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.HashSet;
import java.util.Collections;
import java.util.Map;
import java.util.HashMap;
import org.junit.Assert; import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -343,7 +344,19 @@ class SqlParserReplaceHelperTest {
@Test @Test
void replaceFunctionName() { void replaceFunctionName() {
String sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" String sql = "select 公司名称,平均(注册资本),总部地点 from 互联网企业 where\n"
+ "年营业额 >= 28800000000 and 最大(注册资本)>10000 \n"
+ " group by 公司名称 having 平均(注册资本)>10000 order by \n"
+ "平均(注册资本) desc limit 5";
Map<String, String> map = new HashMap<>();
map.put("平均", "avg");
map.put("最大", "max");
sql = SqlParserReplaceHelper.replaceFunction(sql, map);
System.out.println(sql);
Assert.assertEquals("SELECT 公司名称, avg(注册资本), 总部地点 FROM 互联网企业 WHERE 年营业额 >= 28800000000 AND "
+ "max(注册资本) > 10000 GROUP BY 公司名称 HAVING avg(注册资本) > 10000 ORDER BY avg(注册资本) DESC LIMIT 5", sql);
sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where"
+ " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)";
Map<String, String> functionMap = new HashMap<>(); Map<String, String> functionMap = new HashMap<>();
functionMap.put("MONTH".toLowerCase(), "toMonth"); functionMap.put("MONTH".toLowerCase(), "toMonth");