mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(Headless) If it is in DETAIL mode and select *, add default metrics and dimensions. (#1186)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -24,6 +25,7 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
|||||||
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
|
import net.sf.jsqlparser.statement.select.AllColumns;
|
||||||
import net.sf.jsqlparser.statement.select.GroupByElement;
|
import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||||
import net.sf.jsqlparser.statement.select.Select;
|
import net.sf.jsqlparser.statement.select.Select;
|
||||||
@@ -37,20 +39,24 @@ import org.springframework.util.CollectionUtils;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class SqlRemoveHelper {
|
public class SqlRemoveHelper {
|
||||||
|
|
||||||
public static String removeSelect(String sql, Set<String> fields) {
|
public static String removeAsteriskAndAddFields(String sql, Set<String> needAddDefaultFields) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
if (selectStatement == null) {
|
if (Objects.isNull(selectStatement)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
List<SelectItem<?>> selectItems = ((PlainSelect) selectStatement).getSelectItems();
|
List<SelectItem<?>> selectItems = ((PlainSelect) selectStatement).getSelectItems();
|
||||||
selectItems.removeIf(selectItem -> {
|
if (selectItems.stream().anyMatch(item -> item.getExpression() instanceof AllColumns)) {
|
||||||
String columnName = SqlSelectHelper.getColumnName(selectItem.getExpression());
|
selectItems.clear();
|
||||||
return fields.contains(columnName);
|
List<SelectItem<Column>> columnSelectItems = new ArrayList<>();
|
||||||
});
|
for (String fieldName : needAddDefaultFields) {
|
||||||
|
SelectItem<Column> selectExpressionItem = new SelectItem(new Column(fieldName));
|
||||||
|
columnSelectItems.add(selectExpressionItem);
|
||||||
|
}
|
||||||
|
selectItems.addAll(columnSelectItems);
|
||||||
|
}
|
||||||
return selectStatement.toString();
|
return selectStatement.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +65,6 @@ public class SqlRemoveHelper {
|
|||||||
if (selectStatement == null) {
|
if (selectStatement == null) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
@@ -79,8 +84,6 @@ public class SqlRemoveHelper {
|
|||||||
|
|
||||||
public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
|
public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
|
||||||
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
@@ -105,8 +108,6 @@ public class SqlRemoveHelper {
|
|||||||
if (selectStatement == null) {
|
if (selectStatement == null) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
|
||||||
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
@@ -184,7 +185,6 @@ public class SqlRemoveHelper {
|
|||||||
InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression(
|
InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression(
|
||||||
JsqlConstants.IN_CONSTANT);
|
JsqlConstants.IN_CONSTANT);
|
||||||
inExpression.setLeftExpression(constantExpression.getLeftExpression());
|
inExpression.setLeftExpression(constantExpression.getLeftExpression());
|
||||||
//inExpression.setRightItemsList(constantExpression.getRightItemsList());
|
|
||||||
inExpression.setRightExpression(constantExpression.getRightExpression());
|
inExpression.setRightExpression(constantExpression.getRightExpression());
|
||||||
inExpression.setASTNode(constantExpression.getASTNode());
|
inExpression.setASTNode(constantExpression.getASTNode());
|
||||||
} catch (JSQLParserException e) {
|
} catch (JSQLParserException e) {
|
||||||
@@ -211,8 +211,6 @@ public class SqlRemoveHelper {
|
|||||||
|
|
||||||
public static String removeHavingCondition(String sql, Set<String> removeFieldNames) {
|
public static String removeHavingCondition(String sql, Set<String> removeFieldNames) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
|
||||||
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import net.sf.jsqlparser.expression.Expression;
|
|||||||
import net.sf.jsqlparser.expression.Function;
|
import net.sf.jsqlparser.expression.Function;
|
||||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
|
import net.sf.jsqlparser.statement.select.AllColumns;
|
||||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||||
import net.sf.jsqlparser.statement.select.Select;
|
import net.sf.jsqlparser.statement.select.Select;
|
||||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||||
@@ -45,8 +46,6 @@ public class SqlSelectFunctionHelper {
|
|||||||
|
|
||||||
public static Set<String> getFunctions(String sql) {
|
public static Set<String> getFunctions(String sql) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
|
||||||
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
@@ -106,5 +105,18 @@ public class SqlSelectFunctionHelper {
|
|||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static boolean hasAsterisk(String sql) {
|
||||||
|
List<PlainSelect> plainSelectList = SqlSelectHelper.getPlainSelect(sql);
|
||||||
|
if (CollectionUtils.isEmpty(plainSelectList)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (PlainSelect plainSelect : plainSelectList) {
|
||||||
|
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
|
||||||
|
if (selectItems.stream().anyMatch(item -> item.getExpression() instanceof AllColumns)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.common.jsqlparser;
|
|||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
@@ -12,6 +11,20 @@ import org.junit.jupiter.api.Test;
|
|||||||
*/
|
*/
|
||||||
class SqlRemoveHelperTest {
|
class SqlRemoveHelperTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testRemoveAsterisk() {
|
||||||
|
String sql = "select * from 歌曲库";
|
||||||
|
Set<String> fields = new HashSet<>();
|
||||||
|
fields.add("歌曲名");
|
||||||
|
fields.add("性别");
|
||||||
|
sql = SqlRemoveHelper.removeAsteriskAndAddFields(sql, fields);
|
||||||
|
Assert.assertEquals(sql, "SELECT 歌曲名, 性别 FROM 歌曲库");
|
||||||
|
|
||||||
|
sql = "select 歌曲名 from 歌曲库";
|
||||||
|
sql = SqlRemoveHelper.removeAsteriskAndAddFields(sql, fields);
|
||||||
|
Assert.assertEquals(sql, "SELECT 歌曲名 FROM 歌曲库");
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testRemoveSameFieldFromSelect() {
|
void testRemoveSameFieldFromSelect() {
|
||||||
String sql = "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)"
|
String sql = "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)"
|
||||||
|
|||||||
@@ -72,4 +72,15 @@ class SqlSelectFunctionHelperTest {
|
|||||||
|
|
||||||
Assert.assertEquals(hasFunction, true);
|
Assert.assertEquals(hasFunction, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testHasAsterisk() {
|
||||||
|
String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' "
|
||||||
|
+ "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1";
|
||||||
|
|
||||||
|
Assert.assertEquals(SqlSelectFunctionHelper.hasAsterisk(sql), false);
|
||||||
|
sql = "select * from 超音数 where 数据日期 = '2023-08-08' "
|
||||||
|
+ "and 用户 =alice and 发布日期 ='11'";
|
||||||
|
Assert.assertEquals(SqlSelectFunctionHelper.hasAsterisk(sql), true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,18 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.headless.api.pojo;
|
||||||
|
|
||||||
import lombok.Data;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DataSetSchema {
|
public class DataSetSchema {
|
||||||
|
|
||||||
private SchemaElement dataSet;
|
private SchemaElement dataSet;
|
||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
@@ -78,4 +84,32 @@ public class DataSetSchema {
|
|||||||
return queryConfig.getTagTypeDefaultConfig();
|
return queryConfig.getTagTypeDefaultConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTagDefaultDimensions() {
|
||||||
|
TagTypeDefaultConfig tagTypeDefaultConfig = getTagTypeDefaultConfig();
|
||||||
|
if (Objects.isNull(tagTypeDefaultConfig) || Objects.isNull(tagTypeDefaultConfig.getDefaultDisplayInfo())) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
|
||||||
|
return tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
|
||||||
|
.stream().map(id -> {
|
||||||
|
SchemaElement metric = getElement(SchemaElementType.METRIC, id);
|
||||||
|
return metric;
|
||||||
|
}).filter(Objects::nonNull).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTagDefaultMetrics() {
|
||||||
|
TagTypeDefaultConfig tagTypeDefaultConfig = getTagTypeDefaultConfig();
|
||||||
|
if (Objects.isNull(tagTypeDefaultConfig) || Objects.isNull(tagTypeDefaultConfig.getDefaultDisplayInfo())) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
|
||||||
|
return tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
|
||||||
|
.map(id -> getElement(SchemaElementType.DIMENSION, id))
|
||||||
|
.filter(Objects::nonNull).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.chat.corrector;
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.core.env.Environment;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -22,6 +14,10 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* basic semantic correction functionality, offering common methods and an
|
* basic semantic correction functionality, offering common methods and an
|
||||||
@@ -75,27 +71,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected String addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
|
||||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
|
||||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
|
||||||
|
|
||||||
//decide whether add order by expression field to select
|
|
||||||
Environment environment = ContextUtils.getBean(Environment.class);
|
|
||||||
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
|
|
||||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
|
||||||
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
|
||||||
return correctS2SQL;
|
|
||||||
}
|
|
||||||
|
|
||||||
needAddFields.removeAll(selectFields);
|
|
||||||
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
|
|
||||||
return addFieldsToSelectSql;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
|||||||
@@ -1,11 +1,24 @@
|
|||||||
package com.tencent.supersonic.headless.chat.corrector;
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.core.env.Environment;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -27,8 +40,71 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
&& aggregateFields.size() == selectFields.size()) {
|
&& aggregateFields.size() == selectFields.size()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
correctS2SQL = addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
correctS2SQL = addFieldsToSelect(queryContext, semanticParseInfo, correctS2SQL);
|
||||||
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
|
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
|
||||||
|
String correctS2SQL) {
|
||||||
|
correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL);
|
||||||
|
|
||||||
|
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||||
|
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||||
|
|
||||||
|
//decide whether add order by expression field to select
|
||||||
|
String correctorAdditionalInfo = getAdditionalInfo();
|
||||||
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
|
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||||
|
return correctS2SQL;
|
||||||
|
}
|
||||||
|
needAddFields.removeAll(selectFields);
|
||||||
|
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
|
||||||
|
return addFieldsToSelectSql;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
|
||||||
|
String correctS2SQL) {
|
||||||
|
//If it is in DETAIL mode and select *, add default metrics and dimensions.
|
||||||
|
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
|
||||||
|
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
|
||||||
|
return correctS2SQL;
|
||||||
|
}
|
||||||
|
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||||
|
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
|
Set<String> needAddDefaultFields = new HashSet<>();
|
||||||
|
if (Objects.nonNull(dataSetSchema)) {
|
||||||
|
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
|
||||||
|
Set<String> metrics = dataSetSchema.getTagDefaultMetrics()
|
||||||
|
.stream().map(schemaElement -> schemaElement.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
needAddDefaultFields.addAll(metrics);
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) {
|
||||||
|
Set<String> dimensions = dataSetSchema.getTagDefaultDimensions()
|
||||||
|
.stream().map(schemaElement -> schemaElement.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
needAddDefaultFields.addAll(dimensions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// remove * in sql and add default fields.
|
||||||
|
if (!CollectionUtils.isEmpty(needAddDefaultFields)) {
|
||||||
|
correctS2SQL = SqlRemoveHelper.removeAsteriskAndAddFields(correctS2SQL, needAddDefaultFields);
|
||||||
|
}
|
||||||
|
return correctS2SQL;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getAdditionalInfo() {
|
||||||
|
String correctorAdditionalInfo = null;
|
||||||
|
try {
|
||||||
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
|
correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("getAdditionalInfo error:{}", e);
|
||||||
|
}
|
||||||
|
return correctorAdditionalInfo;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,102 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
class SelectCorrectorTest {
|
||||||
|
|
||||||
|
Long dataSetId = 2L;
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testDoCorrect() {
|
||||||
|
BaseSemanticCorrector corrector = new SelectCorrector();
|
||||||
|
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||||
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
|
SchemaElement dataSet = new SchemaElement();
|
||||||
|
dataSet.setDataSet(dataSetId);
|
||||||
|
semanticParseInfo.setDataSet(dataSet);
|
||||||
|
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||||
|
SqlInfo sqlInfo = new SqlInfo();
|
||||||
|
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
|
||||||
|
sqlInfo.setS2SQL(sql);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||||
|
corrector.correct(queryContext, semanticParseInfo);
|
||||||
|
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
|
||||||
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
|
}
|
||||||
|
|
||||||
|
private QueryContext buildQueryContext(Long dataSetId) {
|
||||||
|
QueryContext queryContext = new QueryContext();
|
||||||
|
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||||
|
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||||
|
QueryConfig queryConfig = new QueryConfig();
|
||||||
|
TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig();
|
||||||
|
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
|
||||||
|
List<Long> dimensionIds = new ArrayList<>();
|
||||||
|
dimensionIds.add(1L);
|
||||||
|
dimensionIds.add(2L);
|
||||||
|
dimensionIds.add(3L);
|
||||||
|
defaultDisplayInfo.setDimensionIds(dimensionIds);
|
||||||
|
|
||||||
|
List<Long> metricIds = new ArrayList<>();
|
||||||
|
metricIds.add(4L);
|
||||||
|
defaultDisplayInfo.setMetricIds(metricIds);
|
||||||
|
|
||||||
|
tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
|
||||||
|
queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig);
|
||||||
|
|
||||||
|
dataSetSchema.setQueryConfig(queryConfig);
|
||||||
|
SchemaElement schemaElement = new SchemaElement();
|
||||||
|
schemaElement.setDataSet(dataSetId);
|
||||||
|
dataSetSchema.setDataSet(schemaElement);
|
||||||
|
Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
|
SchemaElement element1 = new SchemaElement();
|
||||||
|
element1.setDataSet(dataSetId);
|
||||||
|
element1.setId(1L);
|
||||||
|
element1.setName("艺人名");
|
||||||
|
dimensions.add(element1);
|
||||||
|
|
||||||
|
SchemaElement element2 = new SchemaElement();
|
||||||
|
element2.setDataSet(dataSetId);
|
||||||
|
element2.setId(2L);
|
||||||
|
element2.setName("性别");
|
||||||
|
dimensions.add(element2);
|
||||||
|
|
||||||
|
SchemaElement element3 = new SchemaElement();
|
||||||
|
element3.setDataSet(dataSetId);
|
||||||
|
element3.setId(3L);
|
||||||
|
element3.setName("国籍");
|
||||||
|
dimensions.add(element3);
|
||||||
|
|
||||||
|
dataSetSchema.setDimensions(dimensions);
|
||||||
|
|
||||||
|
Set<SchemaElement> metrics = new HashSet<>();
|
||||||
|
SchemaElement metric1 = new SchemaElement();
|
||||||
|
metric1.setDataSet(dataSetId);
|
||||||
|
metric1.setId(4L);
|
||||||
|
metric1.setName("粉丝数");
|
||||||
|
metrics.add(metric1);
|
||||||
|
|
||||||
|
dataSetSchema.setMetrics(metrics);
|
||||||
|
dataSetSchemaList.add(dataSetSchema);
|
||||||
|
|
||||||
|
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||||
|
queryContext.setSemanticSchema(semanticSchema);
|
||||||
|
return queryContext;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user