(improvement)(chat) show logic SQL in Chinese and convert names to bizName on execute (#156)

This commit is contained in:
lexluo09
2023-09-27 17:27:31 +08:00
committed by GitHub
parent f931951ad5
commit 617db611c3
17 changed files with 138 additions and 100 deletions

View File

@@ -6,10 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -23,14 +23,11 @@ import org.springframework.util.CollectionUtils;
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
public static final String DATE_FIELD = "数据日期";
/**
 * Copies the SQL currently held by {@code semanticCorrectInfo} into its
 * {@code preSql} property; the {@code sql} property itself is not modified here.
 *
 * @param semanticCorrectInfo holder whose {@code sql} is read and whose {@code preSql} is set
 */
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
    String currentSql = semanticCorrectInfo.getSql();
    semanticCorrectInfo.setPreSql(currentSql);
}
protected Map<String, String> getFieldToBizName(Long modelId) {
protected Map<String, String> getFieldNameMap(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -40,8 +37,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
Map<String, String> result = dbAllFields.stream()
.filter(entry -> entry.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getName(), (k1, k2) -> k1));
result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD);
return result;
}
@@ -55,9 +52,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName());
whereFields.remove(DateUtils.DATE_FIELD);
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
semanticCorrectInfo.setSql(replaceFields);
}
@@ -75,7 +70,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
}).collect(Collectors.toMap(a -> a.getName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;

View File

@@ -27,7 +27,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldNameByBizName(semanticCorrectInfo);
correctFieldName(semanticCorrectInfo);
addAggregateToMetric(semanticCorrectInfo);
}
@@ -39,11 +39,11 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
semanticCorrectInfo.setSql(replaceAlias);
}
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) {
private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) {
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId());
Map<String, String> fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId());
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName);
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap);
semanticCorrectInfo.setSql(sql);
}

View File

@@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@@ -28,8 +28,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
.filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName()))
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
.filter(schemaElement -> !DateUtils.DATE_FIELD.equals(schemaElement.getBizName()))
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);

View File

@@ -25,7 +25,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet());
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) {
return;

View File

@@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -70,9 +70,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) {
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate);
sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
}
semanticCorrectInfo.setSql(sql);
}
@@ -83,7 +83,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
@@ -117,11 +117,11 @@ public class WhereCorrector extends BaseSemanticCorrector {
for (SchemaElement dimension : dimensions) {
if (Objects.isNull(dimension)
|| Strings.isEmpty(dimension.getBizName())
|| Strings.isEmpty(dimension.getName())
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
continue;
}
String bizName = dimension.getBizName();
String name = dimension.getName();
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
@@ -141,7 +141,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
}
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
result.put(bizName, aliasAndBizNameToTechName);
result.put(name, aliasAndBizNameToTechName);
}
}
return result;

View File

@@ -16,7 +16,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver;
import com.tencent.supersonic.chat.query.QueryManager;
@@ -31,11 +30,11 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList;
import java.util.Arrays;
@@ -113,7 +112,7 @@ public class LLMDslParser implements SemanticParser {
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
&& allFields.contains(schemaElement.getBizName())
&& allFields.contains(schemaElement.getName())
).collect(Collectors.toSet());
}
@@ -122,7 +121,7 @@ public class LLMDslParser implements SemanticParser {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !TimeDimensionEnum.getNameList().contains(entry))
.filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry))
.collect(Collectors.toList());
}
@@ -130,6 +129,7 @@ public class LLMDslParser implements SemanticParser {
String correctorSql = semanticCorrectInfo.getSql();
parseInfo.getSqlInfo().setLogicSql(correctorSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
//set dataInfo
try {
@@ -143,8 +143,8 @@ public class LLMDslParser implements SemanticParser {
//set filter
try {
Map<String, SchemaElement> bizNameToElement = getBizNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(bizNameToElement, expressions);
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
@@ -173,20 +173,18 @@ public class LLMDslParser implements SemanticParser {
}
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
String bizName = expression.getFieldName();
SchemaElement schemaElement = bizNameToElement.get(bizName);
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
if (Objects.isNull(schemaElement)) {
continue;
}
String fieldName = schemaElement.getName();
dimensionFilter.setName(fieldName);
dimensionFilter.setBizName(bizName);
dimensionFilter.setName(schemaElement.getName());
dimensionFilter.setBizName(schemaElement.getBizName());
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
@@ -198,13 +196,8 @@ public class LLMDslParser implements SemanticParser {
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> {
List<String> nameList = TimeDimensionEnum.getNameList();
if (StringUtils.isEmpty(expression.getFieldName())) {
return false;
}
return nameList.contains(expression.getFieldName().toLowerCase());
}).collect(Collectors.toList());
.filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
}
@@ -354,7 +347,7 @@ public class LLMDslParser implements SemanticParser {
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
fieldNameList.add(DateUtils.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
@@ -391,7 +384,7 @@ public class LLMDslParser implements SemanticParser {
}
protected Map<String, SchemaElement> getBizNameToElement(Long modelId) {
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics();
@@ -401,7 +394,7 @@ public class LLMDslParser implements SemanticParser {
allElements.addAll(metrics);
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2));
.collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
}

View File

@@ -68,7 +68,7 @@ class LLMDslParserTest {
model.setId(2L);
parseInfo.setModel(model);
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ")
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 ")
.parseInfo(parseInfo)
.build();

View File

@@ -16,6 +16,7 @@ public class DateUtils {
public static final String DATE_FORMAT = "yyyy-MM-dd";
public static final String DATE_FIELD = "数据日期";
public static final String TIME_FORMAT = "yyyy-MM-dd HH:mm:ss";
public static Integer currentYear() {

View File

@@ -9,16 +9,16 @@ import net.sf.jsqlparser.schema.Column;
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName;
private Map<String, String> fieldNameMap;
private boolean exactReplace;
public FieldReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) {
this.fieldToBizName = fieldToBizName;
public FieldReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldNameMap = fieldNameMap;
this.exactReplace = exactReplace;
}
@Override
public void visit(Column column) {
parseVisitorHelper.replaceColumn(column, fieldToBizName, exactReplace);
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace);
}
}

View File

@@ -17,11 +17,11 @@ import org.apache.commons.lang3.StringUtils;
public class GroupByReplaceVisitor implements GroupByVisitor {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName;
private Map<String, String> fieldNameMap;
private boolean exactReplace;
public GroupByReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) {
this.fieldToBizName = fieldToBizName;
public GroupByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldNameMap = fieldNameMap;
this.exactReplace = exactReplace;
}
@@ -33,7 +33,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
for (int i = 0; i < groupByExpressions.size(); i++) {
Expression expression = groupByExpressions.get(i);
String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName,
String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldNameMap,
exactReplace);
if (StringUtils.isNotEmpty(replaceColumn)) {
if (expression instanceof Column) {

View File

@@ -11,11 +11,11 @@ import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter;
public class OrderByReplaceVisitor extends OrderByVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName;
private Map<String, String> fieldNameMap;
private boolean exactReplace;
public OrderByReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) {
this.fieldToBizName = fieldToBizName;
public OrderByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldNameMap = fieldNameMap;
this.exactReplace = exactReplace;
}
@@ -23,14 +23,14 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter {
public void visit(OrderByElement orderBy) {
Expression expression = orderBy.getExpression();
if (expression instanceof Column) {
parseVisitorHelper.replaceColumn((Column) expression, fieldToBizName, exactReplace);
parseVisitorHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace);
}
if (expression instanceof Function) {
Function function = (Function) expression;
List<Expression> expressions = function.getParameters().getExpressions();
for (Expression column : expressions) {
if (column instanceof Column) {
parseVisitorHelper.replaceColumn((Column) column, fieldToBizName, exactReplace);
parseVisitorHelper.replaceColumn((Column) column, fieldNameMap, exactReplace);
}
}
}

View File

@@ -11,23 +11,23 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j
public class ParseVisitorHelper {
public void replaceColumn(Column column, Map<String, String> fieldToBizName, boolean exactReplace) {
public void replaceColumn(Column column, Map<String, String> fieldNameMap, boolean exactReplace) {
String columnName = column.getColumnName();
String replaceColumn = getReplaceColumn(columnName, fieldToBizName, exactReplace);
String replaceColumn = getReplaceColumn(columnName, fieldNameMap, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) {
column.setColumnName(replaceColumn);
}
}
public String getReplaceColumn(String columnName, Map<String, String> fieldToBizName, boolean exactReplace) {
String fieldBizName = fieldToBizName.get(columnName);
if (StringUtils.isNotBlank(fieldBizName)) {
return fieldBizName;
public String getReplaceColumn(String columnName, Map<String, String> fieldNameMap, boolean exactReplace) {
String fieldName = fieldNameMap.get(columnName);
if (StringUtils.isNotBlank(fieldName)) {
return fieldName;
}
if (exactReplace) {
return null;
}
Optional<Entry<String, String>> first = fieldToBizName.entrySet().stream().sorted((k1, k2) -> {
Optional<Entry<String, String>> first = fieldNameMap.entrySet().stream().sorted((k1, k2) -> {
String k1FieldNameDb = k1.getKey();
String k2FieldNameDb = k2.getKey();
Double k1Similarity = getSimilarity(columnName, k1FieldNameDb);

View File

@@ -65,11 +65,11 @@ public class SqlParserUpdateHelper {
return selectStatement.toString();
}
public static String replaceFields(String sql, Map<String, String> fieldToBizName) {
return replaceFields(sql, fieldToBizName, false);
public static String replaceFields(String sql, Map<String, String> fieldNameMap) {
return replaceFields(sql, fieldNameMap, false);
}
public static String replaceFields(String sql, Map<String, String> fieldToBizName, boolean exactReplace) {
public static String replaceFields(String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
@@ -78,7 +78,7 @@ public class SqlParserUpdateHelper {
PlainSelect plainSelect = (PlainSelect) selectBody;
//1. replace where fields
Expression where = plainSelect.getWhere();
FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldToBizName, exactReplace);
FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace);
if (Objects.nonNull(where)) {
where.accept(visitor);
}
@@ -92,14 +92,14 @@ public class SqlParserUpdateHelper {
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
if (!CollectionUtils.isEmpty(orderByElements)) {
for (OrderByElement orderByElement : orderByElements) {
orderByElement.accept(new OrderByReplaceVisitor(fieldToBizName, exactReplace));
orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace));
}
}
//4. replace group by fields
GroupByElement groupByElement = plainSelect.getGroupBy();
if (Objects.nonNull(groupByElement)) {
groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName, exactReplace));
groupByElement.accept(new GroupByReplaceVisitor(fieldNameMap, exactReplace));
}
//5. replace having fields
Expression having = plainSelect.getHaving();

View File

@@ -36,5 +36,4 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector

View File

@@ -36,5 +36,4 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector

View File

@@ -1,6 +1,10 @@
package com.tencent.supersonic.semantic.query.parser.convert;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.request.SqlExecuteReq;
import com.tencent.supersonic.semantic.api.model.response.DatabaseResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
@@ -17,6 +21,7 @@ import com.tencent.supersonic.semantic.query.utils.QueryStructUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -31,6 +36,7 @@ import org.springframework.util.CollectionUtils;
@Slf4j
public class QueryReqConverter {
public static final String TABLE_PREFIX = "t_";
@Autowired
private ModelService domainService;
@Autowired
@@ -41,38 +47,36 @@ public class QueryReqConverter {
@Autowired
private Catalog catalog;
public QueryStatement convert(QueryDslReq databaseReq, List<ModelSchemaResp> domainSchemas) throws Exception {
public QueryStatement convert(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) throws Exception {
List<MetricTable> tables = new ArrayList<>();
MetricTable metricTable = new MetricTable();
List<String> allFields = SqlParserSelectHelper.getAllFields(databaseReq.getSql());
if (Objects.isNull(modelSchemaResp)) {
return new QueryStatement();
}
//1.convert name to bizName
convertNameToBizName(databaseReq, modelSchemaResp);
//2.functionName corrector
functionNameCorrector(databaseReq);
//3.correct tableName
correctTableName(databaseReq);
String tableName = SqlParserSelectHelper.getTableName(databaseReq.getSql());
functionNameCorrector(databaseReq);
if (CollectionUtils.isEmpty(domainSchemas) || StringUtils.isEmpty(tableName)) {
if (StringUtils.isEmpty(tableName)) {
return new QueryStatement();
}
Set<String> dimensions = domainSchemas.get(0).getDimensions().stream()
.map(entry -> entry.getBizName().toLowerCase())
.collect(Collectors.toSet());
dimensions.addAll(QueryStructUtils.internalCols);
List<String> allFields = SqlParserSelectHelper.getAllFields(databaseReq.getSql());
Set<String> metrics = domainSchemas.get(0).getMetrics().stream().map(entry -> entry.getBizName().toLowerCase())
.collect(Collectors.toSet());
List<String> metrics = getMetrics(modelSchemaResp, allFields);
metricTable.setMetrics(metrics);
Set<String> dimensions = getDimensions(modelSchemaResp, allFields);
metricTable.setDimensions(new ArrayList<>(dimensions));
metricTable.setMetrics(allFields.stream().filter(entry -> metrics.contains(entry.toLowerCase()))
.map(String::toLowerCase).collect(Collectors.toList()));
Set<String> collect = allFields.stream().filter(entry -> dimensions.contains(entry.toLowerCase()))
.map(String::toLowerCase).collect(Collectors.toSet());
for (String internalCol : QueryStructUtils.internalCols) {
if (databaseReq.getSql().contains(internalCol)) {
collect.add(internalCol);
}
}
metricTable.setDimensions(new ArrayList<>(collect));
metricTable.setAlias(tableName.toLowerCase());
// if metric empty , fill model default
if (CollectionUtils.isEmpty(metricTable.getMetrics())) {
@@ -92,6 +96,33 @@ public class QueryReqConverter {
return queryStatement;
}
/**
 * Replaces field names in the request's SQL using the name-to-bizName map built
 * from {@code modelSchemaResp}, then stores the rewritten SQL back on the request.
 * The SQL is logged at info level before and after the replacement.
 *
 * @param databaseReq     request whose SQL is read and overwritten
 * @param modelSchemaResp schema supplying the name-to-bizName mapping
 */
private void convertNameToBizName(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) {
    final String originalSql = databaseReq.getSql();
    final Map<String, String> nameToBizName = getFieldNameToBizNameMap(modelSchemaResp);

    log.info("convert name to bizName before:{}", originalSql);
    // third argument is the exactReplace flag of SqlParserUpdateHelper.replaceFields
    final String convertedSql = SqlParserUpdateHelper.replaceFields(originalSql, nameToBizName, true);
    log.info("convert name to bizName after:{}", convertedSql);

    databaseReq.setSql(convertedSql);
}
/**
 * Selects, from the fields referenced by the SQL, those that are dimensions.
 *
 * <p>A field counts as a dimension when its lower-cased form equals the lower-cased
 * bizName of a dimension in {@code modelSchemaResp}, or is contained in
 * {@code QueryStructUtils.internalCols}.
 *
 * @param modelSchemaResp schema whose dimensions define the candidate names
 * @param allFields       field names extracted from the SQL
 * @return the matching field names, lower-cased and de-duplicated
 */
private Set<String> getDimensions(ModelSchemaResp modelSchemaResp, List<String> allFields) {
    // Collect into an explicit HashSet: the set is extended below, and
    // Collectors.toSet() gives no guarantee that its result is mutable.
    Set<String> allDimensions = modelSchemaResp.getDimensions().stream()
            .map(entry -> entry.getBizName())
            // a dimension without a bizName can never match a SQL field; skip it instead of failing
            .filter(Objects::nonNull)
            // Locale.ROOT keeps identifier case-folding independent of the JVM default locale
            .map(bizName -> bizName.toLowerCase(Locale.ROOT))
            .collect(Collectors.toCollection(HashSet::new));
    allDimensions.addAll(QueryStructUtils.internalCols);

    return allFields.stream()
            .map(field -> field.toLowerCase(Locale.ROOT))
            .filter(allDimensions::contains)
            .collect(Collectors.toSet());
}
/**
 * Selects, from the fields referenced by the SQL, those that are metrics.
 *
 * <p>A field counts as a metric when its lower-cased form equals the lower-cased
 * bizName of a metric in {@code modelSchemaResp}.
 *
 * @param modelSchemaResp schema whose metrics define the candidate names
 * @param allFields       field names extracted from the SQL
 * @return the matching field names, lower-cased, in the order (and with the
 *         repetitions) in which they occur in {@code allFields}
 */
private List<String> getMetrics(ModelSchemaResp modelSchemaResp, List<String> allFields) {
    Set<String> allMetrics = modelSchemaResp.getMetrics().stream()
            .map(entry -> entry.getBizName())
            // a metric without a bizName can never match a SQL field; skip it instead of failing
            .filter(Objects::nonNull)
            // Locale.ROOT keeps identifier case-folding independent of the JVM default locale
            .map(bizName -> bizName.toLowerCase(Locale.ROOT))
            .collect(Collectors.toSet());

    return allFields.stream()
            .map(field -> field.toLowerCase(Locale.ROOT))
            .filter(allMetrics::contains)
            .collect(Collectors.toList());
}
private void functionNameCorrector(QueryDslReq databaseReq) {
DatabaseResp database = catalog.getDatabaseByModelId(databaseReq.getModelId());
if (Objects.isNull(database) || Objects.isNull(database.getType())) {
@@ -107,4 +138,21 @@ public class QueryReqConverter {
}
}
/**
 * Builds the lookup from a schema item's name to its bizName for every dimension
 * and metric of {@code modelSchemaResp}, plus one entry mapping
 * {@code DateUtils.DATE_FIELD} to {@code TimeDimensionEnum.DAY.getName()}.
 *
 * <p>When two items share a name, the first one encountered (dimensions before
 * metrics) is kept. Items with a null name or a null bizName are skipped.
 *
 * @param modelSchemaResp schema supplying the dimensions and metrics
 * @return a mutable map from name to bizName
 */
protected Map<String, String> getFieldNameToBizNameMap(ModelSchemaResp modelSchemaResp) {
    List<SchemaItem> allSchemaItems = new ArrayList<>();
    allSchemaItems.addAll(modelSchemaResp.getDimensions());
    allSchemaItems.addAll(modelSchemaResp.getMetrics());

    // Filled by hand into a HashMap rather than via Collectors.toMap: toMap throws
    // NullPointerException on a null value and does not guarantee a mutable result,
    // while this map receives an extra entry below.
    Map<String, String> result = new HashMap<>();
    for (SchemaItem item : allSchemaItems) {
        if (Objects.isNull(item) || Objects.isNull(item.getName()) || Objects.isNull(item.getBizName())) {
            continue;
        }
        // first occurrence wins, matching the previous (k1, k2) -> k1 merge rule
        result.putIfAbsent(item.getName(), item.getBizName());
    }
    result.put(DateUtils.DATE_FIELD, TimeDimensionEnum.DAY.getName());
    return result;
}
/**
 * Passes the request's SQL through {@code SqlParserUpdateHelper.replaceTable} with the
 * table name {@code TABLE_PREFIX + modelId} and stores the result back on the request.
 *
 * @param databaseReq request whose SQL is read and overwritten and whose model id is used
 */
public void correctTableName(QueryDslReq databaseReq) {
    final String currentSql = databaseReq.getSql();
    final String prefixedTableName = TABLE_PREFIX + databaseReq.getModelId();
    final String rewrittenSql = SqlParserUpdateHelper.replaceTable(currentSql, prefixedTableName);
    databaseReq.setSql(rewrittenSql);
}
}

View File

@@ -90,8 +90,11 @@ public class QueryServiceImpl implements QueryService {
filter.setModelIds(modelIds);
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
List<ModelSchemaResp> domainSchemas = schemaService.fetchModelSchema(filter, user);
QueryStatement queryStatement = queryReqConverter.convert(querySqlCmd, domainSchemas);
ModelSchemaResp domainSchema = null;
if (CollectionUtils.isNotEmpty(domainSchemas)) {
domainSchema = domainSchemas.get(0);
}
QueryStatement queryStatement = queryReqConverter.convert(querySqlCmd, domainSchema);
queryStatement.setModelId(querySqlCmd.getModelId());
return queryStatement;
}