From 7fce9bacc2067553bd055752e8443be72244d70f Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Tue, 12 Sep 2023 13:21:23 +0800 Subject: [PATCH] (improvement)(chat) two display modes are supported in dsl: group by and details (#74) --- .../chat/parser/llm/dsl/LLMDslParser.java | 55 +++++++------ .../chat/parser/llm/dsl/LLMDslParserTest.java | 2 +- .../jsqlparser/SqlParserSelectHelper.java | 79 +++++++++++-------- .../jsqlparser/SqlParserSelectHelperTest.java | 13 ++- 4 files changed, 86 insertions(+), 63 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java index 6678f0849..e9984407c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java @@ -102,30 +102,13 @@ public class LLMDslParser implements SemanticParser { llmResp.setCorrectorSql(semanticCorrectInfo.getSql()); - setFilter(semanticCorrectInfo, modelId, parseInfo); - - setDimensionsAndMetrics(modelId, parseInfo, semanticCorrectInfo.getSql()); + updateParseInfo(semanticCorrectInfo, modelId, parseInfo); } catch (Exception e) { log.error("LLMDSLParser error", e); } } - private void setDimensionsAndMetrics(Long modelId, SemanticParseInfo parseInfo, String sql) { - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - - if (Objects.isNull(semanticSchema)) { - return; - } - List allFields = getFieldsExceptDate(sql); - - Set metrics = getElements(modelId, allFields, semanticSchema.getMetrics()); - parseInfo.setMetrics(metrics); - - Set dimensions = getElements(modelId, allFields, semanticSchema.getDimensions()); - parseInfo.setDimensions(dimensions); - } - private Set getElements(Long modelId, List allFields, List elements) { return elements.stream() .filter(schemaElement -> modelId.equals(schemaElement.getModel()) @@ -133,8 +116,7 @@ public class LLMDslParser implements SemanticParser { ).collect(Collectors.toSet()); } - private List getFieldsExceptDate(String sql) { - List allFields = SqlParserSelectHelper.getAllFields(sql); + private List getFieldsExceptDate(List allFields) { if (CollectionUtils.isEmpty(allFields)) { return new ArrayList<>(); } @@ -143,20 +125,19 @@ public class LLMDslParser implements SemanticParser { .collect(Collectors.toList()); } - public void setFilter(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) { + public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) { String correctorSql = semanticCorrectInfo.getPreSql(); if (StringUtils.isEmpty(correctorSql)) { correctorSql = semanticCorrectInfo.getSql(); } List expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); - if (CollectionUtils.isEmpty(expressions)) { - return; - } //set dataInfo try { - DateConf dateInfo = getDateInfo(expressions); - parseInfo.setDateInfo(dateInfo); + if (!CollectionUtils.isEmpty(expressions)) { + DateConf dateInfo = getDateInfo(expressions); + parseInfo.setDateInfo(dateInfo); + } } catch (Exception e) { log.error("set dateInfo error :", e); } @@ -169,6 +150,28 @@ public class LLMDslParser implements SemanticParser { } catch (Exception e) { log.error("set dimensionFilter error :", e); } + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + if (Objects.isNull(semanticSchema)) { + return; + } + List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql())); + + Set metrics = getElements(modelId, allFields, semanticSchema.getMetrics()); + parseInfo.setMetrics(metrics); + + if (SqlParserSelectHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) { + parseInfo.setNativeQuery(false); + List groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql()); + List groupByDimensions = getFieldsExceptDate(groupByFields); + parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions())); + } else { + parseInfo.setNativeQuery(true); + List selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql()); + List selectDimensions = getFieldsExceptDate(selectFields); + parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions())); + } } private List getDimensionFilter(Map bizNameToElement, diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java index 33e86c177..2cd01f0be 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java @@ -74,7 +74,7 @@ class LLMDslParserTest { LLMDslParser llmDslParser = new LLMDslParser(); - llmDslParser.setFilter(semanticCorrectInfo, 2L, parseInfo); + llmDslParser.updateParseInfo(semanticCorrectInfo, 2L, parseInfo); } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index ac3f6a6d6..1c0ba174f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -8,7 +8,6 @@ import java.util.Set; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; @@ -46,27 +45,17 @@ public class SqlParserSelectHelper { return new ArrayList<>(); } Set result = new HashSet<>(); + getWhereFields(plainSelect, result); + return new ArrayList<>(result); + } + + private static void getWhereFields(PlainSelect plainSelect, Set result) { Expression where = plainSelect.getWhere(); if (Objects.nonNull(where)) { where.accept(new FieldAcquireVisitor(result)); } - return new ArrayList<>(result); } - public static List getOrderByFields(String sql) { - PlainSelect plainSelect = getPlainSelect(sql); - if (Objects.isNull(plainSelect)) { - return new ArrayList<>(); - } - Set result = new HashSet<>(); - List orderByElements = plainSelect.getOrderByElements(); - if (!CollectionUtils.isEmpty(orderByElements)) { - for (OrderByElement orderByElement : orderByElements) { - orderByElement.accept(new OrderByAcquireVisitor(result)); - } - } - return new ArrayList<>(result); - } public static List getSelectFields(String sql) { PlainSelect plainSelect = getPlainSelect(sql); @@ -122,6 +111,45 @@ public class SqlParserSelectHelper { } Set result = getSelectFields(plainSelect); + getGroupByFields(plainSelect, result); + + getOrderByFields(plainSelect, result); + + getWhereFields(plainSelect, result); + + return new ArrayList<>(result); + } + + public static List getOrderByFields(String sql) { + PlainSelect plainSelect = getPlainSelect(sql); + if (Objects.isNull(plainSelect)) { + return new ArrayList<>(); + } + Set result = new HashSet<>(); + getOrderByFields(plainSelect, result); + return new ArrayList<>(result); + } + + private static void getOrderByFields(PlainSelect plainSelect, Set result) { + List orderByElements = plainSelect.getOrderByElements(); + if (!CollectionUtils.isEmpty(orderByElements)) { + for (OrderByElement orderByElement : orderByElements) { + orderByElement.accept(new OrderByAcquireVisitor(result)); + } + } + } + + public static List getGroupByFields(String sql) { + PlainSelect plainSelect = getPlainSelect(sql); + if (Objects.isNull(plainSelect)) { + return new ArrayList<>(); + } + HashSet result = new HashSet<>(); + getGroupByFields(plainSelect, result); + return new ArrayList<>(result); + } + + private static void getGroupByFields(PlainSelect plainSelect, Set result) { GroupByElement groupBy = plainSelect.getGroupBy(); if (groupBy != null) { List groupByExpressions = groupBy.getGroupByExpressions(); @@ -132,24 +160,6 @@ public class SqlParserSelectHelper { } } } - List orderByElements = plainSelect.getOrderByElements(); - if (!CollectionUtils.isEmpty(orderByElements)) { - for (OrderByElement orderByElement : orderByElements) { - orderByElement.accept(new OrderByAcquireVisitor(result)); - } - } - - Expression where = plainSelect.getWhere(); - if (where != null) { - where.accept(new ExpressionVisitorAdapter() { - @Override - public void visit(Column column) { - result.add(column.getColumnName()); - } - }); - } - - return new ArrayList<>(result); } public static String getTableName(String sql) { @@ -190,6 +200,5 @@ public class SqlParserSelectHelper { } return false; } - } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index df3bc6628..b5d84c1a7 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -210,7 +210,6 @@ class SqlParserSelectHelperTest { hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); Assert.assertEquals(hasAggregateFunction, false); - sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql); @@ -232,4 +231,16 @@ class SqlParserSelectHelperTest { fieldToBizName.put("转3.0前后30天结算份额衰减", "fdafdfdsa_fdas"); return fieldToBizName; } + + @Test + void getGroupByFields() { + + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; + List selectFields = SqlParserSelectHelper.getGroupByFields(sql); + + Assert.assertEquals(selectFields.contains("部门"), true); + + } + }