From 42c0bea8fc7f9d755fc9ce0aa73f99071273d09d Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Wed, 22 Nov 2023 15:05:10 +0800 Subject: [PATCH] [improvement](chat) rule and llm support replace metric (#415) * [improvement] replace metric * [improvement] replace metric * [improvement] supports replace metric --------- Co-authored-by: zuopengge --- .../chat/service/impl/QueryServiceImpl.java | 16 ++++---- .../jsqlparser/SqlParserRemoveHelper.java | 23 ----------- .../jsqlparser/SqlParserReplaceHelper.java | 41 +++++++++++++++++++ .../SqlParserReplaceHelperTest.java | 21 ++++++++++ 4 files changed, 71 insertions(+), 30 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index c2e47eb53..a0a16d9e4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -154,7 +154,6 @@ public class QueryServiceImpl implements QueryService { }); } } - //5. postProcessor postProcessors.forEach(postProcessor -> { long startTime = System.currentTimeMillis(); @@ -163,7 +162,6 @@ public class QueryServiceImpl implements QueryService { .interfaceName(postProcessor.getClass().getSimpleName()) .type(CostType.POSTPROCESSOR.getType()).build()); }); - //6. responder parseResponders.forEach(parseResponder -> { long startTime = System.currentTimeMillis(); @@ -351,8 +349,12 @@ public class QueryServiceImpl implements QueryService { .map(o -> o.getName()).collect(Collectors.toList()); String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); log.info("before replaceMetrics:{}", correctorSql); - correctorSql = SqlParserAddHelper.addFieldsToSelect(correctorSql, metrics); - correctorSql = SqlParserRemoveHelper.removeSelect(correctorSql, filteredMetrics); + log.info("filteredMetrics:{},metrics:{}", filteredMetrics, metrics); + Map fieldMap = new HashMap<>(); + if (CollectionUtils.isNotEmpty(filteredMetrics) && CollectionUtils.isNotEmpty(metrics)) { + fieldMap.put(filteredMetrics.get(0), metrics.get(0)); + correctorSql = SqlParserReplaceHelper.replaceSelectFields(correctorSql, fieldMap); + } log.info("after replaceMetrics:{}", correctorSql); parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); } @@ -547,9 +549,9 @@ public class QueryServiceImpl implements QueryService { if (CollectionUtils.isNotEmpty(queryData.getDimensions())) { parseInfo.setDimensions(queryData.getDimensions()); } - if (CollectionUtils.isNotEmpty(queryData.getMetrics())) { - parseInfo.setMetrics(queryData.getMetrics()); - } + //if (CollectionUtils.isNotEmpty(queryData.getMetrics())) { + // parseInfo.setMetrics(queryData.getMetrics()); + //} if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { parseInfo.setDimensionFilters(queryData.getDimensionFilters()); } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java index c0f5946b2..2e6736c42 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserRemoveHelper.java @@ -39,29 +39,6 @@ import java.util.Set; @Slf4j public class SqlParserRemoveHelper { - public static String removeSelect(String sql, List filteredMetrics) { - Select selectStatement = SqlParserSelectHelper.getSelect(sql); - SelectBody selectBody = selectStatement.getSelectBody(); - - if (!(selectBody instanceof PlainSelect)) { - return sql; - } - List selectItemList = ((PlainSelect) selectBody).getSelectItems(); - selectItemList.removeIf(o -> { - Expression expression = ((SelectExpressionItem) o).getExpression(); - if (expression instanceof Column) { - Column column = (Column) expression; - String columnName = column.getColumnName(); - if (filteredMetrics.contains(columnName)) { - return true; - } - } - return false; - }); - ((PlainSelect) selectBody).setSelectItems(selectItemList); - return selectStatement.toString(); - } - public static String removeSelect(String sql, Set fields) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); if (selectStatement == null) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java index 03aade580..ed9c05f17 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java @@ -4,6 +4,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.ArrayList; + import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.expression.Expression; @@ -17,6 +19,7 @@ import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.MinorThan; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo; +import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.Join; import net.sf.jsqlparser.statement.select.OrderByElement; @@ -36,6 +39,44 @@ import org.springframework.util.CollectionUtils; @Slf4j public class SqlParserReplaceHelper { + public static String replaceSelectFields(String sql, Map fieldNameMap) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + ((PlainSelect) selectBody).getSelectItems().stream().forEach(o -> { + SelectExpressionItem selectExpressionItem = (SelectExpressionItem) o; + String alias = ""; + if (selectExpressionItem.getExpression() instanceof Function) { + Function function = (Function) selectExpressionItem.getExpression(); + Column column = (Column) function.getParameters().getExpressions().get(0); + if (fieldNameMap.containsKey(column.getColumnName())) { + String value = fieldNameMap.get(column.getColumnName()); + alias = value; + List expressions = new ArrayList<>(); + expressions.add(new Column(value)); + function.getParameters().setExpressions(expressions); + } + } + if (selectExpressionItem.getExpression() instanceof Column) { + Column column = (Column) selectExpressionItem.getExpression(); + String columnName = column.getColumnName(); + if (fieldNameMap.containsKey(columnName)) { + String value = fieldNameMap.get(columnName); + alias = value; + if (StringUtils.isNotBlank(value)) { + selectExpressionItem.setExpression(new Column(value)); + } + } + } + if (Objects.nonNull(selectExpressionItem.getAlias()) && StringUtils.isNotBlank(alias)) { + selectExpressionItem.getAlias().setName(alias); + } + }); + return selectStatement.toString(); + } + public static String replaceValue(String sql, Map> filedNameToValueMap) { return replaceValue(sql, filedNameToValueMap, true); } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java index 8240209e6..c0acc55bb 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelperTest.java @@ -14,6 +14,27 @@ import org.junit.jupiter.api.Test; */ class SqlParserReplaceHelperTest { + @Test + void replaceSelectField() { + + String sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1"; + Map fieldMap = new HashMap<>(); + fieldMap.put("播放量", "播放量1"); + sql = SqlParserReplaceHelper.replaceSelectFields(sql, fieldMap); + System.out.println(sql); + Assert.assertEquals("SELECT 维度1, sum(播放量1) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql); + + sql = "SELECT 维度1,播放量 FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1"; + fieldMap = new HashMap<>(); + fieldMap.put("播放量", "播放量1"); + sql = SqlParserReplaceHelper.replaceSelectFields(sql, fieldMap); + System.out.println(sql); + Assert.assertEquals("SELECT 维度1, 播放量1 FROM 数据库 WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql); + } + @Test void replaceValue() {