mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) two display modes are supported in dsl: group by and details (#74)
This commit is contained in:
@@ -102,30 +102,13 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
|
|
||||||
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
|
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
|
||||||
|
|
||||||
setFilter(semanticCorrectInfo, modelId, parseInfo);
|
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
|
||||||
|
|
||||||
setDimensionsAndMetrics(modelId, parseInfo, semanticCorrectInfo.getSql());
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("LLMDSLParser error", e);
|
log.error("LLMDSLParser error", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setDimensionsAndMetrics(Long modelId, SemanticParseInfo parseInfo, String sql) {
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
|
|
||||||
if (Objects.isNull(semanticSchema)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<String> allFields = getFieldsExceptDate(sql);
|
|
||||||
|
|
||||||
Set<SchemaElement> metrics = getElements(modelId, allFields, semanticSchema.getMetrics());
|
|
||||||
parseInfo.setMetrics(metrics);
|
|
||||||
|
|
||||||
Set<SchemaElement> dimensions = getElements(modelId, allFields, semanticSchema.getDimensions());
|
|
||||||
parseInfo.setDimensions(dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
||||||
return elements.stream()
|
return elements.stream()
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
||||||
@@ -133,8 +116,7 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
).collect(Collectors.toSet());
|
).collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<String> getFieldsExceptDate(String sql) {
|
private List<String> getFieldsExceptDate(List<String> allFields) {
|
||||||
List<String> allFields = SqlParserSelectHelper.getAllFields(sql);
|
|
||||||
if (CollectionUtils.isEmpty(allFields)) {
|
if (CollectionUtils.isEmpty(allFields)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
@@ -143,20 +125,19 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setFilter(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
|
public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
|
||||||
|
|
||||||
String correctorSql = semanticCorrectInfo.getPreSql();
|
String correctorSql = semanticCorrectInfo.getPreSql();
|
||||||
if (StringUtils.isEmpty(correctorSql)) {
|
if (StringUtils.isEmpty(correctorSql)) {
|
||||||
correctorSql = semanticCorrectInfo.getSql();
|
correctorSql = semanticCorrectInfo.getSql();
|
||||||
}
|
}
|
||||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
||||||
if (CollectionUtils.isEmpty(expressions)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
//set dataInfo
|
//set dataInfo
|
||||||
try {
|
try {
|
||||||
DateConf dateInfo = getDateInfo(expressions);
|
if (!CollectionUtils.isEmpty(expressions)) {
|
||||||
parseInfo.setDateInfo(dateInfo);
|
DateConf dateInfo = getDateInfo(expressions);
|
||||||
|
parseInfo.setDateInfo(dateInfo);
|
||||||
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("set dateInfo error :", e);
|
log.error("set dateInfo error :", e);
|
||||||
}
|
}
|
||||||
@@ -169,6 +150,28 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("set dimensionFilter error :", e);
|
log.error("set dimensionFilter error :", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
|
||||||
|
if (Objects.isNull(semanticSchema)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql()));
|
||||||
|
|
||||||
|
Set<SchemaElement> metrics = getElements(modelId, allFields, semanticSchema.getMetrics());
|
||||||
|
parseInfo.setMetrics(metrics);
|
||||||
|
|
||||||
|
if (SqlParserSelectHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) {
|
||||||
|
parseInfo.setNativeQuery(false);
|
||||||
|
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
|
||||||
|
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||||
|
parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions()));
|
||||||
|
} else {
|
||||||
|
parseInfo.setNativeQuery(true);
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql());
|
||||||
|
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||||
|
parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
|
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class LLMDslParserTest {
|
|||||||
|
|
||||||
LLMDslParser llmDslParser = new LLMDslParser();
|
LLMDslParser llmDslParser = new LLMDslParser();
|
||||||
|
|
||||||
llmDslParser.setFilter(semanticCorrectInfo, 2L, parseInfo);
|
llmDslParser.updateParseInfo(semanticCorrectInfo, 2L, parseInfo);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -8,7 +8,6 @@ import java.util.Set;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
import net.sf.jsqlparser.schema.Table;
|
import net.sf.jsqlparser.schema.Table;
|
||||||
@@ -46,27 +45,17 @@ public class SqlParserSelectHelper {
|
|||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
Set<String> result = new HashSet<>();
|
Set<String> result = new HashSet<>();
|
||||||
|
getWhereFields(plainSelect, result);
|
||||||
|
return new ArrayList<>(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void getWhereFields(PlainSelect plainSelect, Set<String> result) {
|
||||||
Expression where = plainSelect.getWhere();
|
Expression where = plainSelect.getWhere();
|
||||||
if (Objects.nonNull(where)) {
|
if (Objects.nonNull(where)) {
|
||||||
where.accept(new FieldAcquireVisitor(result));
|
where.accept(new FieldAcquireVisitor(result));
|
||||||
}
|
}
|
||||||
return new ArrayList<>(result);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<String> getOrderByFields(String sql) {
|
|
||||||
PlainSelect plainSelect = getPlainSelect(sql);
|
|
||||||
if (Objects.isNull(plainSelect)) {
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
Set<String> result = new HashSet<>();
|
|
||||||
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
|
|
||||||
if (!CollectionUtils.isEmpty(orderByElements)) {
|
|
||||||
for (OrderByElement orderByElement : orderByElements) {
|
|
||||||
orderByElement.accept(new OrderByAcquireVisitor(result));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return new ArrayList<>(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static List<String> getSelectFields(String sql) {
|
public static List<String> getSelectFields(String sql) {
|
||||||
PlainSelect plainSelect = getPlainSelect(sql);
|
PlainSelect plainSelect = getPlainSelect(sql);
|
||||||
@@ -122,6 +111,45 @@ public class SqlParserSelectHelper {
|
|||||||
}
|
}
|
||||||
Set<String> result = getSelectFields(plainSelect);
|
Set<String> result = getSelectFields(plainSelect);
|
||||||
|
|
||||||
|
getGroupByFields(plainSelect, result);
|
||||||
|
|
||||||
|
getOrderByFields(plainSelect, result);
|
||||||
|
|
||||||
|
getWhereFields(plainSelect, result);
|
||||||
|
|
||||||
|
return new ArrayList<>(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<String> getOrderByFields(String sql) {
|
||||||
|
PlainSelect plainSelect = getPlainSelect(sql);
|
||||||
|
if (Objects.isNull(plainSelect)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
Set<String> result = new HashSet<>();
|
||||||
|
getOrderByFields(plainSelect, result);
|
||||||
|
return new ArrayList<>(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void getOrderByFields(PlainSelect plainSelect, Set<String> result) {
|
||||||
|
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
|
||||||
|
if (!CollectionUtils.isEmpty(orderByElements)) {
|
||||||
|
for (OrderByElement orderByElement : orderByElements) {
|
||||||
|
orderByElement.accept(new OrderByAcquireVisitor(result));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<String> getGroupByFields(String sql) {
|
||||||
|
PlainSelect plainSelect = getPlainSelect(sql);
|
||||||
|
if (Objects.isNull(plainSelect)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
HashSet<String> result = new HashSet<>();
|
||||||
|
getGroupByFields(plainSelect, result);
|
||||||
|
return new ArrayList<>(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void getGroupByFields(PlainSelect plainSelect, Set<String> result) {
|
||||||
GroupByElement groupBy = plainSelect.getGroupBy();
|
GroupByElement groupBy = plainSelect.getGroupBy();
|
||||||
if (groupBy != null) {
|
if (groupBy != null) {
|
||||||
List<Expression> groupByExpressions = groupBy.getGroupByExpressions();
|
List<Expression> groupByExpressions = groupBy.getGroupByExpressions();
|
||||||
@@ -132,24 +160,6 @@ public class SqlParserSelectHelper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
|
|
||||||
if (!CollectionUtils.isEmpty(orderByElements)) {
|
|
||||||
for (OrderByElement orderByElement : orderByElements) {
|
|
||||||
orderByElement.accept(new OrderByAcquireVisitor(result));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Expression where = plainSelect.getWhere();
|
|
||||||
if (where != null) {
|
|
||||||
where.accept(new ExpressionVisitorAdapter() {
|
|
||||||
@Override
|
|
||||||
public void visit(Column column) {
|
|
||||||
result.add(column.getColumnName());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return new ArrayList<>(result);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static String getTableName(String sql) {
|
public static String getTableName(String sql) {
|
||||||
@@ -190,6 +200,5 @@ public class SqlParserSelectHelper {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -210,7 +210,6 @@ class SqlParserSelectHelperTest {
|
|||||||
hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql);
|
hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql);
|
||||||
Assert.assertEquals(hasAggregateFunction, false);
|
Assert.assertEquals(hasAggregateFunction, false);
|
||||||
|
|
||||||
|
|
||||||
sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' "
|
sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' "
|
||||||
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10";
|
+ "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10";
|
||||||
hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql);
|
hasAggregateFunction = SqlParserSelectHelper.hasAggregateFunction(sql);
|
||||||
@@ -232,4 +231,16 @@ class SqlParserSelectHelperTest {
|
|||||||
fieldToBizName.put("转3.0前后30天结算份额衰减", "fdafdfdsa_fdas");
|
fieldToBizName.put("转3.0前后30天结算份额衰减", "fdafdfdsa_fdas");
|
||||||
return fieldToBizName;
|
return fieldToBizName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void getGroupByFields() {
|
||||||
|
|
||||||
|
String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'"
|
||||||
|
+ " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1";
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getGroupByFields(sql);
|
||||||
|
|
||||||
|
Assert.assertEquals(selectFields.contains("部门"), true);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user