(improvement)(chat) recall historical solved queries on every parse

This commit is contained in:
jolunoluo
2023-09-25 17:11:21 +08:00
parent e1772c25c4
commit 34816451c0
9 changed files with 352 additions and 11 deletions

View File

@@ -21,5 +21,4 @@ public class QueryResult {
private SemanticParseInfo chatContext; private SemanticParseInfo chatContext;
private Object response; private Object response;
private List<Map<String, Object>> queryResults; private List<Map<String, Object>> queryResults;
private List<SolvedQueryRecallResp> similarSolvedQuery;
} }

View File

@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -32,6 +33,11 @@ public class GlobalCorrector extends BaseSemanticCorrector {
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) {
return;
}
} }
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {

View File

@@ -95,7 +95,7 @@ public class DefaultQueryResponder implements QueryResponder {
} }
} }
} catch (Exception e) { } catch (Exception e) {
log.warn("recall similar solved query failed", e); log.warn("recall similar solved query failed, queryText:{}", queryText);
} }
return solvedQueryRecallResps; return solvedQueryRecallResps;
} }

View File

@@ -143,15 +143,15 @@ public class QueryServiceImpl implements QueryService {
saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(), saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue()); queryReq.getUser().getName(), queryReq.getChatId().longValue());
} else { } else {
List<SolvedQueryRecallResp> solvedQueryRecallResps =
queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText());
parseResult = ParseResp.builder() parseResult = ParseResp.builder()
.chatId(queryReq.getChatId()) .chatId(queryReq.getChatId())
.queryText(queryReq.getQueryText()) .queryText(queryReq.getQueryText())
.state(ParseResp.ParseState.FAILED) .state(ParseResp.ParseState.FAILED)
.similarSolvedQuery(solvedQueryRecallResps)
.build(); .build();
} }
List<SolvedQueryRecallResp> solvedQueryRecallResps =
queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText());
parseResult.setSimilarSolvedQuery(solvedQueryRecallResps);
return parseResult; return parseResult;
} }

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.common.util.jsqlparser;
import lombok.Data;
/**
 * Plain data holder pairing a field name with an operator.
 * Accessors, equals/hashCode and toString are generated by Lombok's {@code @Data}.
 * NOTE(review): "Filed" looks like a typo for "Field"; the name is kept because renaming would break callers.
 */
@Data
public class FiledExpression {
// Operator associated with the field; presumably a comparison operator such as "=" or ">" (confirm against callers).
private String operator;
// Name of the field the operator applies to.
private String fieldName;
}

View File

@@ -0,0 +1,113 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
/**
 * Expression visitor that extracts comparisons whose left side is a function call over one of the
 * configured fields (for example {@code sum(pv) > 2000}) so the caller can re-attach them elsewhere,
 * e.g. as a HAVING condition.
 *
 * <p>For every matching comparison the visitor records a freshly parsed copy of the original
 * condition in {@link #getWaitingForAdds()} and overwrites the operands of the visited node with the
 * literals {@code 1} and {@code 1}.
 *
 * <p>NOTE(review): only the operands are replaced, the operator is kept. {@code sum(pv) > 2000}
 * therefore leaves {@code 1 > 1} behind, which is always false. Confirm this is intended.
 */
@Slf4j
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {

    /** Conditions removed from the visited expression, in visiting order. */
    private final List<Expression> waitingForAdds = new ArrayList<>();

    /** Column names whose function comparisons should be extracted. */
    private final Set<String> fieldNames;

    /**
     * @param fieldNames column names whose function comparisons should be extracted
     */
    public FiledFilterReplaceVisitor(Set<String> fieldNames) {
        this.fieldNames = fieldNames;
    }

    @Override
    public void visit(MinorThan expr) {
        collectHavingCandidates(expr);
    }

    @Override
    public void visit(EqualsTo expr) {
        collectHavingCandidates(expr);
    }

    @Override
    public void visit(MinorThanEquals expr) {
        collectHavingCandidates(expr);
    }

    @Override
    public void visit(GreaterThan expr) {
        collectHavingCandidates(expr);
    }

    @Override
    public void visit(GreaterThanEquals expr) {
        collectHavingCandidates(expr);
    }

    /**
     * @return the conditions collected so far (the live internal list, not a copy)
     */
    public List<Expression> getWaitingForAdds() {
        return waitingForAdds;
    }

    /**
     * Extracts the comparison if its left side is a function over one of the configured fields.
     *
     * @param comparisonOperator comparison to inspect; its operands are overwritten when it matches
     * @return a list holding a parsed copy of the original condition when it matched; an empty list
     *         when the comparison is not applicable (left side is not a function, is a date
     *         function, has no arguments, or its first argument is not a plain column);
     *         {@code null} when the column is not one of the configured fields or re-parsing failed
     */
    public List<Expression> parserFilter(ComparisonOperator comparisonOperator) {
        List<Expression> result = new ArrayList<>();
        // Capture the text before the node is modified below.
        String originalText = comparisonOperator.toString();

        Expression leftExpression = comparisonOperator.getLeftExpression();
        if (!(leftExpression instanceof Function)) {
            return result;
        }
        Function leftFunction = (Function) leftExpression;
        if (leftFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) {
            return result;
        }
        // Functions such as count(*) carry no parameter list at all.
        if (leftFunction.getParameters() == null) {
            return result;
        }
        List<Expression> arguments = leftFunction.getParameters().getExpressions();
        if (CollectionUtils.isEmpty(arguments)) {
            return result;
        }
        // Only plain column arguments can be matched by name; skip e.g. sum(a + b).
        Expression firstArgument = arguments.get(0);
        if (!(firstArgument instanceof Column)) {
            return result;
        }
        String columnName = ((Column) firstArgument).getColumnName();
        if (!fieldNames.contains(columnName)) {
            return null;
        }
        try {
            // Parse everything first so a parse failure cannot leave the node half-replaced
            // with the original condition lost.
            Expression extractedCondition = CCJSqlParserUtil.parseCondExpression(originalText);
            ComparisonOperator placeholder =
                    (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 ");
            comparisonOperator.setLeftExpression(placeholder.getLeftExpression());
            comparisonOperator.setRightExpression(placeholder.getRightExpression());
            comparisonOperator.setASTNode(placeholder.getASTNode());
            result.add(extractedCondition);
            return result;
        } catch (JSQLParserException e) {
            log.error("JSQLParserException", e);
        }
        return null;
    }

    /** Runs {@link #parserFilter(ComparisonOperator)} and records whatever it extracted. */
    private void collectHavingCandidates(ComparisonOperator expr) {
        List<Expression> extracted = parserFilter(expr);
        if (extracted != null) {
            waitingForAdds.addAll(extracted);
        }
    }
}

View File

@@ -205,6 +205,30 @@ public class SqlParserSelectHelper {
/**
 * Returns true when {@code hasFunction(sql)} does; otherwise returns the result of
 * {@code hasGroupBy(sql)}.
 *
 * @param sql the SQL text to inspect
 */
public static boolean hasAggregateFunction(String sql) { public static boolean hasAggregateFunction(String sql) {
if (hasFunction(sql)) {
return true;
}
// No function found: fall back to the GROUP BY check.
return hasGroupBy(sql);
}
/**
 * Checks the GROUP BY clause of a plain select.
 *
 * @param sql the SQL text to inspect
 * @return {@code false} when the statement is not a plain select or has no GROUP BY element;
 *         otherwise the flag reported by a {@code GroupByVisitor} run over that element
 */
public static boolean hasGroupBy(String sql) {
    SelectBody body = getSelect(sql).getSelectBody();
    if (body instanceof PlainSelect) {
        GroupByElement groupByElement = ((PlainSelect) body).getGroupBy();
        if (groupByElement != null) {
            GroupByVisitor groupByVisitor = new GroupByVisitor();
            groupByElement.accept(groupByVisitor);
            return groupByVisitor.isHasAggregateFunction();
        }
    }
    return false;
}
public static boolean hasFunction(String sql) {
Select selectStatement = getSelect(sql); Select selectStatement = getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody(); SelectBody selectBody = selectStatement.getSelectBody();
@@ -221,12 +245,6 @@ public class SqlParserSelectHelper {
if (selectFunction) { if (selectFunction) {
return true; return true;
} }
GroupByElement groupBy = plainSelect.getGroupBy();
if (Objects.nonNull(groupBy)) {
GroupByVisitor replaceVisitor = new GroupByVisitor();
groupBy.accept(replaceVisitor);
return replaceVisitor.isHasAggregateFunction();
}
return false; return false;
} }
} }

View File

@@ -11,6 +11,7 @@ import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.GroupByElement;
@@ -20,6 +21,7 @@ import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import net.sf.jsqlparser.util.SelectUtils; import net.sf.jsqlparser.util.SelectUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -278,5 +280,114 @@ public class SqlParserUpdateHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
/**
 * Wraps select and order-by fields in the aggregate function configured for them.
 *
 * @param sql                  the SQL text to rewrite
 * @param fieldNameToAggregate column name to aggregate function name, e.g. {@code pv -> sum}
 * @return the rewritten SQL, or {@code sql} unchanged when {@code hasGroupBy} reports true or
 *         the statement is not a plain select
 */
public static String addAggregateToField(String sql, Map<String, String> fieldNameToAggregate) {
    if (SqlParserSelectHelper.hasGroupBy(sql)) {
        return sql;
    }
    Select select = SqlParserSelectHelper.getSelect(sql);
    SelectBody body = select.getSelectBody();
    if (body instanceof PlainSelect) {
        PlainSelect plainSelect = (PlainSelect) body;
        // Select items first, then order-by items.
        addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate);
        addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate);
        return select.toString();
    }
    return sql;
}
/**
 * Adds a GROUP BY clause over the given fields to a plain select.
 *
 * @param sql           the SQL text to rewrite
 * @param groupByFields column names to group by, in order
 * @return the rewritten SQL, or {@code sql} unchanged when there is nothing to group by,
 *         {@code hasGroupBy} reports true, or the statement is not a plain select
 */
public static String addGroupBy(String sql, List<String> groupByFields) {
    // A null list would throw and an empty one would emit a dangling "GROUP BY".
    if (CollectionUtils.isEmpty(groupByFields)) {
        return sql;
    }
    if (SqlParserSelectHelper.hasGroupBy(sql)) {
        return sql;
    }
    Select selectStatement = SqlParserSelectHelper.getSelect(sql);
    SelectBody selectBody = selectStatement.getSelectBody();
    if (!(selectBody instanceof PlainSelect)) {
        return sql;
    }
    PlainSelect plainSelect = (PlainSelect) selectBody;
    GroupByElement groupByElement = new GroupByElement();
    for (String groupByField : groupByFields) {
        groupByElement.addGroupByExpression(new Column(groupByField));
    }
    plainSelect.setGroupByElement(groupByElement);
    return selectStatement.toString();
}
/**
 * Replaces each plain-column select item with its configured aggregate function call.
 * Items that are not plain columns (already a function, arithmetic, literals, ...) are left
 * untouched rather than being cast blindly to {@code Column}.
 *
 * @param selectItems          select items of the statement; may be null
 * @param fieldNameToAggregate column name to aggregate function name
 */
private static void addAggregateToSelectItems(List<SelectItem> selectItems,
        Map<String, String> fieldNameToAggregate) {
    if (selectItems == null) {
        return;
    }
    for (SelectItem selectItem : selectItems) {
        if (!(selectItem instanceof SelectExpressionItem)) {
            continue;
        }
        SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
        Expression expression = selectExpressionItem.getExpression();
        // e.g. "select department, sum(pv) ..." carries a Function, not a Column.
        if (!(expression instanceof Column)) {
            continue;
        }
        String columnName = ((Column) expression).getColumnName();
        Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
        if (Objects.isNull(function)) {
            continue;
        }
        selectExpressionItem.setExpression(function);
    }
}
/**
 * Replaces each plain-column ORDER BY expression with its configured aggregate function call.
 * Expressions that are not plain columns (for example {@code ORDER BY sum(pv)}) are left
 * untouched rather than being cast blindly to {@code Column}.
 *
 * @param orderByElements      ORDER BY elements of the statement; may be null
 * @param fieldNameToAggregate column name to aggregate function name
 */
private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
        Map<String, String> fieldNameToAggregate) {
    if (orderByElements == null) {
        return;
    }
    for (OrderByElement orderByElement : orderByElements) {
        Expression expression = orderByElement.getExpression();
        if (!(expression instanceof Column)) {
            continue;
        }
        String columnName = ((Column) expression).getColumnName();
        if (StringUtils.isEmpty(columnName)) {
            continue;
        }
        Function function = getFunction(expression, fieldNameToAggregate.get(columnName));
        if (Objects.isNull(function)) {
            continue;
        }
        orderByElement.setExpression(function);
    }
}
/**
 * Builds the call {@code aggregateName(expression)}.
 *
 * @param expression    the single argument of the function
 * @param aggregateName function name; when null or empty no function is built
 * @return the function call, or {@code null} when {@code aggregateName} is null or empty
 */
private static Function getFunction(Expression expression, String aggregateName) {
    if (!StringUtils.isEmpty(aggregateName)) {
        Function aggregate = new Function();
        aggregate.setName(aggregateName);
        aggregate.setParameters(new ExpressionList(expression));
        return aggregate;
    }
    return null;
}
/**
 * Moves function comparisons on the given fields out of WHERE and into HAVING.
 *
 * <p>The WHERE expression is walked with a {@code FiledFilterReplaceVisitor}, which rewrites the
 * matching comparisons in place and hands back the original conditions. Those conditions are
 * AND-combined with each other and with any HAVING clause already present, so none of them is
 * lost.
 *
 * @param sql        the SQL text to rewrite
 * @param fieldNames column names whose function comparisons should be moved
 * @return the rewritten SQL, or {@code sql} unchanged when the statement is not a plain select
 */
public static String addHaving(String sql, Set<String> fieldNames) {
    Select selectStatement = SqlParserSelectHelper.getSelect(sql);
    SelectBody selectBody = selectStatement.getSelectBody();
    if (!(selectBody instanceof PlainSelect)) {
        return sql;
    }
    PlainSelect plainSelect = (PlainSelect) selectBody;
    // The visitor replaces each matching comparison's operands in WHERE and collects the originals.
    Expression where = plainSelect.getWhere();
    FiledFilterReplaceVisitor visitor = new FiledFilterReplaceVisitor(fieldNames);
    if (Objects.nonNull(where)) {
        where.accept(visitor);
    }
    List<Expression> waitingForAdds = visitor.getWaitingForAdds();
    if (!CollectionUtils.isEmpty(waitingForAdds)) {
        // Start from the existing HAVING (if any) and AND every extracted condition onto it.
        Expression having = plainSelect.getHaving();
        for (Expression waitingForAdd : waitingForAdds) {
            having = Objects.isNull(having) ? waitingForAdd : new AndExpression(having, waitingForAdd);
        }
        plainSelect.setHaving(having);
    }
    return selectStatement.toString();
}
} }

View File

@@ -1,9 +1,12 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.parser.CCJSqlParserUtil;
@@ -266,6 +269,85 @@ class SqlParserUpdateHelperTest {
} }
/**
 * The function used in HAVING is appended to the select list when missing, and not duplicated
 * when it is already selected.
 */
@Test
void addAggregateToField() {
    String expectedSql = "SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' "
            + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";

    // Case 1: sum(pv) appears only in HAVING.
    String sqlWithoutMetric = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND "
            + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
    Expression having = SqlParserSelectHelper.getHavingExpression(sqlWithoutMetric);
    String rewritten = SqlParserUpdateHelper.addFunctionToSelect(sqlWithoutMetric, having);
    System.out.println(rewritten);
    Assert.assertEquals(expectedSql, rewritten);

    // Case 2: sum(pv) is already selected.
    String sqlWithMetric = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND "
            + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000";
    having = SqlParserSelectHelper.getHavingExpression(sqlWithMetric);
    rewritten = SqlParserUpdateHelper.addFunctionToSelect(sqlWithMetric, having);
    System.out.println(rewritten);
    Assert.assertEquals(expectedSql, rewritten);
}
/**
 * A bare metric column is wrapped in its aggregate in both SELECT and ORDER BY, and the
 * dimension is then added as GROUP BY.
 */
@Test
void addAggregateToMetricField() {
    Map<String, String> aggregateByField = Collections.singletonMap("pv", "sum");
    List<String> dimensions = Collections.singletonList("department");

    String aggregated = SqlParserUpdateHelper.addAggregateToField(
            "select department, pv from t_1 where sys_imp_date = '2023-09-11' order by pv desc limit 10",
            aggregateByField);
    String grouped = SqlParserUpdateHelper.addGroupBy(aggregated, dimensions);

    String expectedSql = "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
            + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10";
    Assert.assertEquals(expectedSql, grouped);
}
/** A select that already aggregates but lacks GROUP BY gets the dimension added as GROUP BY. */
@Test
void addGroupBy() {
    String inputSql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' "
            + "order by sum(pv) desc limit 10";
    List<String> dimensions = Collections.singletonList("department");

    String grouped = SqlParserUpdateHelper.addGroupBy(inputSql, dimensions);

    String expectedSql = "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
            + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10";
    Assert.assertEquals(expectedSql, grouped);
}
/**
 * A metric comparison placed in WHERE is moved to HAVING.
 *
 * <p>NOTE(review): the expected SQL pins {@code 1 > 1} as the placeholder left behind in WHERE.
 * That predicate is always false; confirm this is the intended rewrite.
 */
@Test
void addHaving() {
    String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and "
            + "sum(pv) > 2000 group by department order by sum(pv) desc limit 10";
    Set<String> fieldNames = new HashSet<>();
    fieldNames.add("pv");
    String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames);
    Assert.assertEquals(
            "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
                    + "AND 1 > 1 GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
            replaceSql);
}
private Map<String, String> initParams() { private Map<String, String> initParams() {
Map<String, String> fieldToBizName = new HashMap<>(); Map<String, String> fieldToBizName = new HashMap<>();
fieldToBizName.put("部门", "department"); fieldToBizName.put("部门", "department");