(improvement)(chat) dsl supports revision (#254)

This commit is contained in:
mainmain
2023-10-18 17:44:14 +08:00
committed by GitHub
parent 7b861f563c
commit 7d770d2a6d
16 changed files with 343 additions and 60 deletions

View File

@@ -3,7 +3,8 @@ package com.tencent.supersonic.chat.persistence.dataobject;
public enum CostType {
MAPPER(1, "mapper"),
PARSER(2, "parser"),
QUERY(3, "query");
QUERY(3, "query"),
PARSERRESPONDER(4, "responder");
private Integer type;
private String name;

View File

@@ -14,7 +14,7 @@ public interface ChatParseMapper {
boolean updateParseInfo(ChatParseDO chatParseDO);
ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
ChatParseDO getParseInfo(Long questionId, int parseId);
List<ChatParseDO> getParseInfoList(List<Long> questionIds);

View File

@@ -34,7 +34,7 @@ public interface ChatQueryRepository {
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses);
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
public ChatParseDO getParseInfo(Long questionId, int parseId);
List<ChatParseDO> getParseInfoList(List<Long> questionIds);

View File

@@ -188,9 +188,9 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO);
}
@Override
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId) {
return chatParseMapper.getParseInfo(questionId, userName, parseId);
public ChatParseDO getParseInfo(Long questionId, int parseId) {
return chatParseMapper.getParseInfo(questionId, parseId);
}
@Override

View File

@@ -62,7 +62,7 @@ public interface ChatService {
Boolean updateQuery(Long questionId, QueryResult queryResult, ChatContext chatCtx);
ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
ChatParseDO getParseInfo(Long questionId, int parseId);
Boolean deleteChatQuery(Long questionId);

View File

@@ -219,8 +219,8 @@ public class ChatServiceImpl implements ChatService {
return tempDate.format(new java.util.Date());
}
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId) {
return chatQueryRepository.getParseInfo(questionId, userName, parseId);
public ChatParseDO getParseInfo(Long questionId, int parseId) {
return chatQueryRepository.getParseInfo(questionId, parseId);
}
public Boolean deleteChatQuery(Long questionId) {

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
@@ -42,25 +43,41 @@ import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.Objects;
import java.util.Set;
import java.util.HashSet;
import java.util.List;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.schema.Column;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.compress.utils.Lists;
@@ -143,7 +160,15 @@ public class QueryServiceImpl implements QueryService {
.build();
}
for (ParseResponder parseResponder : parseResponders) {
Long startTime = System.currentTimeMillis();
parseResponder.fillResponse(parseResult, queryCtx, chatParseDOS);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(parseResponder.getClass().getSimpleName())
.type(CostType.PARSERRESPONDER.getType()).build());
}
if (timeCostDOList.size() > 0) {
saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
}
chatService.updateChatParse(chatParseDOS);
return parseResult;
@@ -157,7 +182,7 @@ public class QueryServiceImpl implements QueryService {
}
private List<SemanticParseInfo> getTop5CandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses) {
List<SemanticParseInfo> candidateParses) {
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
return candidateParses;
}
@@ -182,7 +207,7 @@ public class QueryServiceImpl implements QueryService {
@Override
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getParseId());
queryReq.getParseId());
ChatQueryDO chatQueryDO = chatService.getLastQuery(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>();
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
@@ -303,7 +328,7 @@ public class QueryServiceImpl implements QueryService {
@Override
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException {
ChatParseDO chatParseDO = chatService.getParseInfo(queryData.getQueryId(),
queryData.getUser().getName(), queryData.getParseId());
queryData.getParseId());
SemanticParseInfo parseInfo = getSemanticParseInfo(queryData, chatParseDO);
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
@@ -316,21 +341,33 @@ public class QueryServiceImpl implements QueryService {
LLMResp llmResp = parseResult.getLlmResp();
String correctorSql = llmResp.getCorrectorSql();
log.info("correctorSql before replacing:{}", correctorSql);
List<FilterExpression> filterExpressionList = SqlParserSelectHelper.getFilterExpression(correctorSql);
updateFilters(filedNameToValueMap, filterExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters());
updateFilters(havingFiledNameToValueMap, filterExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters());
updateDateInfo(queryData, parseInfo, filedNameToValueMap, filterExpressionList);
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
List<FilterExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
updateFilters(filedNameToValueMap, whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
log.info("filedNameToValueMap:{}", filedNameToValueMap);
log.info("removeWhereFieldNames:{}", removeWhereFieldNames);
correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
updateFilters(havingFiledNameToValueMap, havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
log.info("havingFiledNameToValueMap:{}", havingFiledNameToValueMap);
log.info("removeHavingFieldNames:{}", removeHavingFieldNames);
correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlParserRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
log.info("addWhereConditions:{}", addWhereConditions);
log.info("addHavingConditions:{}", addHavingConditions);
correctorSql = SqlParserAddHelper.addWhere(correctorSql, addWhereConditions);
correctorSql = SqlParserAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
llmResp.setCorrectorSql(correctorSql);
parseResult.setLlmResp(llmResp);
@@ -344,34 +381,54 @@ public class QueryServiceImpl implements QueryService {
parseInfo.getSqlInfo().setQuerySql(explain.getSql());
}
}
log.info("parseInfo:{}", JsonUtil.toString(semanticQuery.getParseInfo().getProperties()));
semanticQuery.setParseInfo(parseInfo);
QueryResult queryResult = semanticQuery.execute(user);
queryResult.setChatContext(semanticQuery.getParseInfo());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
return queryResult;
}
@Override
public EntityInfo getEntityInfo(Long queryId, Integer parseId, User user) {
ChatParseDO chatParseDO = chatService.getParseInfo(queryId, user.getName(), parseId);
ChatParseDO chatParseDO = chatService.getParseInfo(queryId, parseId);
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
return semanticService.getEntityInfo(parseInfo, user);
}
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap, List<FilterExpression> filterExpressionList) {
Map<String, Map<String, String>> filedNameToValueMap,
List<FilterExpression> filterExpressionList,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (Objects.isNull(queryData.getDateInfo())) {
return;
}
Map<String, String> map = new HashMap<>();
String dateField = DateUtils.DATE_FIELD;
if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
}
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
for (FilterExpression filterExpression : filterExpressionList) {
if (DateUtils.DATE_FIELD.equals(filterExpression.getFieldName())) {
dateField = filterExpression.getFieldName();
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
if (filterExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
dateField = filterExpression.getFieldName();
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
filedNameToValueMap.put(dateField, map);
} else {
removeFieldNames.add(DateUtils.DATE_FIELD);
EqualsTo equalsTo = new EqualsTo();
Column column = new Column(DateUtils.DATE_FIELD);
StringValue stringValue = new StringValue(queryData.getDateInfo().getStartDate());
equalsTo.setLeftExpression(column);
equalsTo.setRightExpression(stringValue);
addConditions.add(equalsTo);
}
break;
}
}
@@ -389,38 +446,133 @@ public class QueryServiceImpl implements QueryService {
map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getEndDate());
}
filedNameToValueMap.put(dateField, map);
if (FilterOperatorEnum.EQUALS.getValue().equals(filterExpression.getOperator())) {
removeFieldNames.add(DateUtils.DATE_FIELD);
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeCondition(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
MinorThanEquals minorThanEquals = new MinorThanEquals();
addTimeCondition(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
}
}
}
}
filedNameToValueMap.put(dateField, map);
parseInfo.setDateInfo(queryData.getDateInfo());
}
public <T extends ComparisonOperator> void addTimeCondition(String date,
T comparisonExpression,
List<Expression> addConditions) {
Column column = new Column(DateUtils.DATE_FIELD);
StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column);
comparisonExpression.setRightExpression(stringValue);
addConditions.add(comparisonExpression);
}
private void updateFilters(Map<String, Map<String, String>> filedNameToValueMap,
List<FilterExpression> filterExpressionList, Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters) {
List<FilterExpression> filterExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (CollectionUtils.isEmpty(metricFilters)) {
return;
}
for (QueryFilter queryFilter : metricFilters) {
for (QueryFilter dslQueryFilter : metricFilters) {
Map<String, String> map = new HashMap<>();
for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName() != null
&& filterExpression.getFieldName().contains(queryFilter.getName())
&& queryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), queryFilter.getValue().toString());
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(queryFilter.getName())) {
o.setValue(queryFilter.getValue());
&& filterExpression.getFieldName().contains(dslQueryFilter.getName())) {
if (dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())
&& Objects.nonNull(dslQueryFilter.getValue())) {
map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString());
filedNameToValueMap.put(dslQueryFilter.getName(), map);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
}
});
} else {
removeFieldNames.add(dslQueryFilter.getName());
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
EqualsTo equalsTo = new EqualsTo();
addWhereCondition(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addWhereCondition(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
GreaterThan greaterThan = new GreaterThan();
addWhereCondition(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
MinorThanEquals minorThanEquals = new MinorThanEquals();
addWhereCondition(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
MinorThan minorThan = new MinorThan();
addWhereCondition(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
InExpression inExpression = new InExpression();
addWhereInCondition(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
}
});
}
break;
}
}
filedNameToValueMap.put(queryFilter.getName(), map);
}
}
public void addWhereInCondition(QueryFilter dslQueryFilter,
InExpression inExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
Column column = new Column(dslQueryFilter.getName());
ExpressionList expressionList = new ExpressionList();
List<Expression> expressions = new ArrayList<>();
List<String> valueList = JsonUtil.toList(
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
if (CollectionUtils.isEmpty(valueList)) {
return;
}
valueList.stream().forEach(o -> {
StringValue stringValue = new StringValue(o);
expressions.add(stringValue);
});
expressionList.setExpressions(expressions);
inExpression.setLeftExpression(column);
inExpression.setRightItemsList(expressionList);
addConditions.add(inExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
public <T extends ComparisonOperator> void addWhereCondition(QueryFilter dslQueryFilter,
T comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
}
if (Objects.isNull(dslQueryFilter.getValue())) {
return;
}
Column column = new Column(columnName);
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
comparisonExpression.setLeftExpression(column);
comparisonExpression.setRightExpression(longValue);
addConditions.add(comparisonExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData, ChatParseDO chatParseDO) {
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);

View File

@@ -34,8 +34,7 @@
<select id="getParseInfo" resultMap="ChatParse">
select *
from s2_chat_parse
where question_id = #{questionId} and user_name = #{userName}
and parse_id = #{parseId} limit 1
where question_id = #{questionId} and parse_id = #{parseId} limit 1
</select>
<select id="getParseInfoList" resultMap="ChatParse">

View File

@@ -59,8 +59,8 @@
join (
select distinct chat_id
from s2_chat_query
where query_state = 1 and agent_id = ${agentId}
<if test="userName != null and userName != ''">
where query_state = 1 and agent_id = ${agentId} and (score is null or score > 1)
<if test="userName != null and userName != ''">
and user_name = #{userName}
</if>
order by chat_id desc

View File

@@ -12,13 +12,14 @@ import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
@@ -46,6 +47,14 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
filterExpressions.add(filterExpression);
}
public void visit(InExpression expr) {
FilterExpression filterExpression = new FilterExpression();
filterExpression.setFieldName(((Column) expr.getLeftExpression()).getColumnName());
filterExpression.setOperator(JsqlConstants.IN);
filterExpression.setFieldValue(expr.getRightItemsList());
filterExpressions.add(filterExpression);
}
@Override
public void visit(MinorThan expr) {
FilterExpression filterExpression = getFilterExpression(expr);
@@ -139,4 +148,4 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
}
return null;
}
}
}

View File

@@ -55,6 +55,7 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
public <T extends Expression> void replaceComparisonExpression(T expression) {
Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression();
Expression rightExpression = ((ComparisonOperator) expression).getRightExpression();
if (!(leftExpression instanceof Column || leftExpression instanceof Function)) {
return;
}

View File

@@ -16,5 +16,6 @@ public class JsqlConstants {
public static final String EQUAL_CONSTANT = " 1 = 1 ";
public static final String IN_CONSTANT = " 1 in (1) ";
public static final String IN = "IN";
}
}

View File

@@ -113,6 +113,31 @@ public class SqlParserAddHelper {
return selectStatement.toString();
}
public static String addWhere(String sql, List<Expression> expressionList) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
if (CollectionUtils.isEmpty(expressionList)) {
return sql;
}
Expression expression = expressionList.get(0);
for (int i = 1; i < expressionList.size(); i++) {
expression = new AndExpression(expression, expressionList.get(i));
}
PlainSelect plainSelect = (PlainSelect) selectBody;
Expression where = plainSelect.getWhere();
if (where == null) {
plainSelect.setWhere(expression);
} else {
plainSelect.setWhere(new AndExpression(where, expression));
}
return selectStatement.toString();
}
public static String addAggregateToField(String sql, Map<String, String> fieldNameToAggregate) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
@@ -274,6 +299,31 @@ public class SqlParserAddHelper {
return selectStatement.toString();
}
public static String addHaving(String sql, List<Expression> expressionList) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
if (CollectionUtils.isEmpty(expressionList)) {
return sql;
}
Expression expression = expressionList.get(0);
for (int i = 1; i < expressionList.size(); i++) {
expression = new AndExpression(expression, expressionList.get(i));
}
PlainSelect plainSelect = (PlainSelect) selectBody;
Expression having = plainSelect.getHaving();
if (having == null) {
plainSelect.setHaving(expression);
} else {
plainSelect.setHaving(new AndExpression(having, expression));
}
return selectStatement.toString();
}
public static String addParenthesisToWhere(String sql) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();

View File

@@ -7,8 +7,12 @@ import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.PlainSelect;
@@ -60,17 +64,36 @@ public class SqlParserRemoveHelper {
}
}
public static String getConstant(Expression expression){
String constant = JsqlConstants.EQUAL_CONSTANT;
if (expression instanceof GreaterThanEquals) {
constant = JsqlConstants.GREATER_THAN_EQUALS_CONSTANT;
} else if (expression instanceof MinorThanEquals) {
constant = JsqlConstants.MINOR_THAN_EQUALS_CONSTANT;
} else if (expression instanceof GreaterThan) {
constant = JsqlConstants.GREATER_THAN_CONSTANT;
} else if (expression instanceof MinorThan) {
constant = JsqlConstants.MINOR_THAN_CONSTANT;
}
return constant;
}
private static void removeExpressionWithConstant(Expression expression, Set<String> removeFieldNames) {
if (expression instanceof EqualsTo) {
if (expression instanceof EqualsTo
|| expression instanceof GreaterThanEquals
|| expression instanceof GreaterThan
|| expression instanceof MinorThanEquals
|| expression instanceof MinorThan) {
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
String columnName = SqlParserSelectHelper.getColumnName(comparisonOperator.getLeftExpression(),
comparisonOperator.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
String constant = getConstant(expression);
try {
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.EQUAL_CONSTANT);
constant);
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
comparisonOperator.setASTNode(constantExpression.getASTNode());
@@ -97,6 +120,22 @@ public class SqlParserRemoveHelper {
}
}
public static String removeHavingCondition(String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
selectBody.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
removeWhereCondition(plainSelect.getHaving(), removeFieldNames);
}
});
return selectStatement.toString();
}
public static String removeWhere(String sql, List<String> fields) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();

View File

@@ -208,6 +208,19 @@ public class SqlParserSelectHelper {
return null;
}
public static List<FilterExpression> getWhereExpressions(String sql) {
PlainSelect plainSelect = getPlainSelect(sql);
if (Objects.isNull(plainSelect)) {
return new ArrayList<>();
}
Set<FilterExpression> result = new HashSet<>();
Expression where = plainSelect.getWhere();
if (Objects.nonNull(where)) {
where.accept(new FieldAndValueAcquireVisitor(result));
}
return new ArrayList<>(result);
}
public static List<FilterExpression> getHavingExpressions(String sql) {
PlainSelect plainSelect = getPlainSelect(sql);
if (Objects.isNull(plainSelect)) {
@@ -317,6 +330,12 @@ public class SqlParserSelectHelper {
if (leftExpression instanceof Column) {
return ((Column) leftExpression).getColumnName();
}
if (leftExpression instanceof Function) {
List<Expression> expressionList = ((Function) leftExpression).getParameters().getExpressions();
if (!CollectionUtils.isEmpty(expressionList) && expressionList.get(0) instanceof Column) {
return ((Column) expressionList.get(0)).getColumnName();
}
}
if (rightExpression instanceof Column) {
return ((Column) rightExpression).getColumnName();
}

View File

@@ -10,6 +10,18 @@ import org.junit.jupiter.api.Test;
*/
class SqlParserRemoveHelperTest {
@Test
void removeHavingCondition() {
String sql = "select 歌曲名 from 歌曲库 where 歌手名 = '周杰伦' HAVING sum(播放量) > 20000";
Set<String> removeFieldNames = new HashSet<>();
removeFieldNames.add("播放量");
String replaceSql = SqlParserRemoveHelper.removeHavingCondition(sql, removeFieldNames);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE 歌手名 = '周杰伦' HAVING 2 > 1",
replaceSql);
}
@Test
void removeWhereCondition() {
String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "