diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index 3665a5d04..ace7f3d6c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -6,10 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -23,14 +23,11 @@ import org.springframework.util.CollectionUtils; @Slf4j public abstract class BaseSemanticCorrector implements SemanticCorrector { - public static final String DATE_FIELD = "数据日期"; - - public void correct(SemanticCorrectInfo semanticCorrectInfo) { semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql()); } - protected Map getFieldToBizName(Long modelId) { + protected Map getFieldNameMap(Long modelId) { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); @@ -40,8 +37,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { Map result = dbAllFields.stream() .filter(entry -> entry.getModel().equals(modelId)) - .collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1)); - result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName()); + .collect(Collectors.toMap(SchemaElement::getName, a -> a.getName(), (k1, k2) -> k1)); + result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD); return result; } @@ -55,9 +52,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); whereFields.removeAll(selectFields); - whereFields.remove(TimeDimensionEnum.DAY.getName()); - whereFields.remove(TimeDimensionEnum.WEEK.getName()); - whereFields.remove(TimeDimensionEnum.MONTH.getName()); + whereFields.remove(DateUtils.DATE_FIELD); String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); semanticCorrectInfo.setSql(replaceFields); } @@ -75,7 +70,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name()); } return schemaElement; - }).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1)); + }).collect(Collectors.toMap(a -> a.getName(), a -> a.getDefaultAgg(), (k1, k2) -> k1)); if (CollectionUtils.isEmpty(metricToAggregate)) { return; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java index 82a1f4f02..57d46c720 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java @@ -27,7 +27,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { updateFieldNameByLinkingValue(semanticCorrectInfo); - updateFieldNameByBizName(semanticCorrectInfo); + correctFieldName(semanticCorrectInfo); addAggregateToMetric(semanticCorrectInfo); } @@ -39,11 +39,11 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { semanticCorrectInfo.setSql(replaceAlias); } - private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) { + private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) { - Map fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()); + Map fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId()); - String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName); + String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap); semanticCorrectInfo.setSql(sql); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index 25018fef4..5fe69d5f0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -28,8 +28,8 @@ public class GroupByCorrector extends BaseSemanticCorrector { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); Set dimensions = semanticSchema.getDimensions(modelId).stream() - .filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName())) - .map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet()); + .filter(schemaElement -> !DateUtils.DATE_FIELD.equals(schemaElement.getBizName())) + .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); List selectFields = SqlParserSelectHelper.getSelectFields(sql); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java index d1dc40d0b..736a056f6 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -25,7 +25,7 @@ public class HavingCorrector extends BaseSemanticCorrector { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); Set metrics = semanticSchema.getMetrics(modelId).stream() - .map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet()); + .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); if (CollectionUtils.isEmpty(metrics)) { return; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index dd16fa462..06dc51a5f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -70,9 +70,9 @@ public class WhereCorrector extends BaseSemanticCorrector { private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) { String sql = semanticCorrectInfo.getSql(); List whereFields = SqlParserSelectHelper.getWhereFields(sql); - if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) { + if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) { String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); - sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate); + sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate); } semanticCorrectInfo.setSql(sql); } @@ -83,7 +83,7 @@ public class WhereCorrector extends BaseSemanticCorrector { } return queryFilters.getFilters().stream() .map(filter -> { - String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); + String bizNameWrap = StringUtil.getSpaceWrap(filter.getName()); String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); return bizNameWrap + operatorWrap + valueWrap; @@ -117,11 +117,11 @@ public class WhereCorrector extends BaseSemanticCorrector { for (SchemaElement dimension : dimensions) { if (Objects.isNull(dimension) - || Strings.isEmpty(dimension.getBizName()) + || Strings.isEmpty(dimension.getName()) || CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) { continue; } - String bizName = dimension.getBizName(); + String name = dimension.getName(); Map aliasAndBizNameToTechName = new HashMap<>(); @@ -141,7 +141,7 @@ public class WhereCorrector extends BaseSemanticCorrector { } } if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) { - result.put(bizName, aliasAndBizNameToTechName); + result.put(name, aliasAndBizNameToTechName); } } return result; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java index fd66fe767..fa8d46a5a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java @@ -16,7 +16,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.config.LLMParserConfig; -import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector; import com.tencent.supersonic.chat.parser.SatisfactionChecker; import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver; import com.tencent.supersonic.chat.query.QueryManager; @@ -31,11 +30,11 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import java.util.ArrayList; import java.util.Arrays; @@ -113,7 +112,7 @@ public class LLMDslParser implements SemanticParser { private Set getElements(Long modelId, List allFields, List elements) { return elements.stream() .filter(schemaElement -> modelId.equals(schemaElement.getModel()) - && allFields.contains(schemaElement.getBizName()) + && allFields.contains(schemaElement.getName()) ).collect(Collectors.toSet()); } @@ -122,7 +121,7 @@ public class LLMDslParser implements SemanticParser { return new ArrayList<>(); } return allFields.stream() - .filter(entry -> !TimeDimensionEnum.getNameList().contains(entry)) + .filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry)) .collect(Collectors.toList()); } @@ -130,6 +129,7 @@ public class LLMDslParser implements SemanticParser { String correctorSql = semanticCorrectInfo.getSql(); parseInfo.getSqlInfo().setLogicSql(correctorSql); + List expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); //set dataInfo try { @@ -143,8 +143,8 @@ public class LLMDslParser implements SemanticParser { //set filter try { - Map bizNameToElement = getBizNameToElement(modelId); - List result = getDimensionFilter(bizNameToElement, expressions); + Map fieldNameToElement = getNameToElement(modelId); + List result = getDimensionFilter(fieldNameToElement, expressions); parseInfo.getDimensionFilters().addAll(result); } catch (Exception e) { log.error("set dimensionFilter error :", e); @@ -173,20 +173,18 @@ public class LLMDslParser implements SemanticParser { } } - private List getDimensionFilter(Map bizNameToElement, + private List getDimensionFilter(Map fieldNameToElement, List filterExpressions) { List result = Lists.newArrayList(); for (FilterExpression expression : filterExpressions) { QueryFilter dimensionFilter = new QueryFilter(); dimensionFilter.setValue(expression.getFieldValue()); - String bizName = expression.getFieldName(); - SchemaElement schemaElement = bizNameToElement.get(bizName); + SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); if (Objects.isNull(schemaElement)) { continue; } - String fieldName = schemaElement.getName(); - dimensionFilter.setName(fieldName); - dimensionFilter.setBizName(bizName); + dimensionFilter.setName(schemaElement.getName()); + dimensionFilter.setBizName(schemaElement.getBizName()); dimensionFilter.setElementID(schemaElement.getId()); FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); @@ -198,13 +196,8 @@ public class LLMDslParser implements SemanticParser { private DateConf getDateInfo(List filterExpressions) { List dateExpressions = filterExpressions.stream() - .filter(expression -> { - List nameList = TimeDimensionEnum.getNameList(); - if (StringUtils.isEmpty(expression.getFieldName())) { - return false; - } - return nameList.contains(expression.getFieldName().toLowerCase()); - }).collect(Collectors.toList()); + .filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName())) + .collect(Collectors.toList()); if (CollectionUtils.isEmpty(dateExpressions)) { return new DateConf(); } @@ -354,7 +347,7 @@ public class LLMDslParser implements SemanticParser { List fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig); - fieldNameList.add(BaseSemanticCorrector.DATE_FIELD); + fieldNameList.add(DateUtils.DATE_FIELD); llmSchema.setFieldNameList(fieldNameList); llmReq.setSchema(llmSchema); @@ -391,7 +384,7 @@ public class LLMDslParser implements SemanticParser { } - protected Map getBizNameToElement(Long modelId) { + protected Map getNameToElement(Long modelId) { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); List dimensions = semanticSchema.getDimensions(); List metrics = semanticSchema.getMetrics(); @@ -401,7 +394,7 @@ public class LLMDslParser implements SemanticParser { allElements.addAll(metrics); return allElements.stream() .filter(schemaElement -> schemaElement.getModel().equals(modelId)) - .collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2)); + .collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2)); } diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java index 2cd01f0be..7c233c448 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParserTest.java @@ -68,7 +68,7 @@ class LLMDslParserTest { model.setId(2L); parseInfo.setModel(model); SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ") + .sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 ") .parseInfo(parseInfo) .build(); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java index dbf0ffe0c..ea6363c65 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java @@ -16,6 +16,7 @@ public class DateUtils { public static final String DATE_FORMAT = "yyyy-MM-dd"; + public static final String DATE_FIELD = "数据日期"; public static final String TIME_FORMAT = "yyyy-MM-dd HH:mm:ss"; public static Integer currentYear() { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java index 870cce668..f75e61324 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java @@ -9,16 +9,16 @@ import net.sf.jsqlparser.schema.Column; public class FieldReplaceVisitor extends ExpressionVisitorAdapter { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); - private Map fieldToBizName; + private Map fieldNameMap; private boolean exactReplace; - public FieldReplaceVisitor(Map fieldToBizName, boolean exactReplace) { - this.fieldToBizName = fieldToBizName; + public FieldReplaceVisitor(Map fieldNameMap, boolean exactReplace) { + this.fieldNameMap = fieldNameMap; this.exactReplace = exactReplace; } @Override public void visit(Column column) { - parseVisitorHelper.replaceColumn(column, fieldToBizName, exactReplace); + parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace); } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java index e60c465d9..6fb3d8f34 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java @@ -17,11 +17,11 @@ import org.apache.commons.lang3.StringUtils; public class GroupByReplaceVisitor implements GroupByVisitor { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); - private Map fieldToBizName; + private Map fieldNameMap; private boolean exactReplace; - public GroupByReplaceVisitor(Map fieldToBizName, boolean exactReplace) { - this.fieldToBizName = fieldToBizName; + public GroupByReplaceVisitor(Map fieldNameMap, boolean exactReplace) { + this.fieldNameMap = fieldNameMap; this.exactReplace = exactReplace; } @@ -33,7 +33,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor { for (int i = 0; i < groupByExpressions.size(); i++) { Expression expression = groupByExpressions.get(i); - String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName, + String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldNameMap, exactReplace); if (StringUtils.isNotEmpty(replaceColumn)) { if (expression instanceof Column) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java index 858400f86..95835d346 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByReplaceVisitor.java @@ -11,11 +11,11 @@ import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter; public class OrderByReplaceVisitor extends OrderByVisitorAdapter { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); - private Map fieldToBizName; + private Map fieldNameMap; private boolean exactReplace; - public OrderByReplaceVisitor(Map fieldToBizName, boolean exactReplace) { - this.fieldToBizName = fieldToBizName; + public OrderByReplaceVisitor(Map fieldNameMap, boolean exactReplace) { + this.fieldNameMap = fieldNameMap; this.exactReplace = exactReplace; } @@ -23,14 +23,14 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter { public void visit(OrderByElement orderBy) { Expression expression = orderBy.getExpression(); if (expression instanceof Column) { - parseVisitorHelper.replaceColumn((Column) expression, fieldToBizName, exactReplace); + parseVisitorHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace); } if (expression instanceof Function) { Function function = (Function) expression; List expressions = function.getParameters().getExpressions(); for (Expression column : expressions) { if (column instanceof Column) { - parseVisitorHelper.replaceColumn((Column) column, fieldToBizName, exactReplace); + parseVisitorHelper.replaceColumn((Column) column, fieldNameMap, exactReplace); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java index b526d6d24..f4a5c63e7 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java @@ -11,23 +11,23 @@ import org.apache.commons.lang3.StringUtils; @Slf4j public class ParseVisitorHelper { - public void replaceColumn(Column column, Map fieldToBizName, boolean exactReplace) { + public void replaceColumn(Column column, Map fieldNameMap, boolean exactReplace) { String columnName = column.getColumnName(); - String replaceColumn = getReplaceColumn(columnName, fieldToBizName, exactReplace); + String replaceColumn = getReplaceColumn(columnName, fieldNameMap, exactReplace); if (StringUtils.isNotBlank(replaceColumn)) { column.setColumnName(replaceColumn); } } - public String getReplaceColumn(String columnName, Map fieldToBizName, boolean exactReplace) { - String fieldBizName = fieldToBizName.get(columnName); - if (StringUtils.isNotBlank(fieldBizName)) { - return fieldBizName; + public String getReplaceColumn(String columnName, Map fieldNameMap, boolean exactReplace) { + String fieldName = fieldNameMap.get(columnName); + if (StringUtils.isNotBlank(fieldName)) { + return fieldName; } if (exactReplace) { return null; } - Optional> first = fieldToBizName.entrySet().stream().sorted((k1, k2) -> { + Optional> first = fieldNameMap.entrySet().stream().sorted((k1, k2) -> { String k1FieldNameDb = k1.getKey(); String k2FieldNameDb = k2.getKey(); Double k1Similarity = getSimilarity(columnName, k1FieldNameDb); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index 01fd38db8..8ed43dff5 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -65,11 +65,11 @@ public class SqlParserUpdateHelper { return selectStatement.toString(); } - public static String replaceFields(String sql, Map fieldToBizName) { - return replaceFields(sql, fieldToBizName, false); + public static String replaceFields(String sql, Map fieldNameMap) { + return replaceFields(sql, fieldNameMap, false); } - public static String replaceFields(String sql, Map fieldToBizName, boolean exactReplace) { + public static String replaceFields(String sql, Map fieldNameMap, boolean exactReplace) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); if (!(selectBody instanceof PlainSelect)) { @@ -78,7 +78,7 @@ public class SqlParserUpdateHelper { PlainSelect plainSelect = (PlainSelect) selectBody; //1. replace where fields Expression where = plainSelect.getWhere(); - FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldToBizName, exactReplace); + FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace); if (Objects.nonNull(where)) { where.accept(visitor); } @@ -92,14 +92,14 @@ public class SqlParserUpdateHelper { List orderByElements = plainSelect.getOrderByElements(); if (!CollectionUtils.isEmpty(orderByElements)) { for (OrderByElement orderByElement : orderByElements) { - orderByElement.accept(new OrderByReplaceVisitor(fieldToBizName, exactReplace)); + orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace)); } } //4. replace group by fields GroupByElement groupByElement = plainSelect.getGroupBy(); if (Objects.nonNull(groupByElement)) { - groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName, exactReplace)); + groupByElement.accept(new GroupByReplaceVisitor(fieldNameMap, exactReplace)); } //5. replace having fields Expression having = plainSelect.getHaving(); diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index 98f9ca666..7e88d99d8 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -36,5 +36,4 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \ - com.tencent.supersonic.chat.corrector.TableCorrector, \ com.tencent.supersonic.chat.corrector.GlobalAfterCorrector \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index cb441f5bd..f5f325a47 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -36,5 +36,4 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \ - com.tencent.supersonic.chat.corrector.TableCorrector, \ com.tencent.supersonic.chat.corrector.GlobalAfterCorrector \ No newline at end of file diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java index 520377951..7fe0399f1 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/QueryReqConverter.java @@ -1,6 +1,10 @@ package com.tencent.supersonic.semantic.query.parser.convert; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; +import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import com.tencent.supersonic.semantic.api.model.request.SqlExecuteReq; import com.tencent.supersonic.semantic.api.model.response.DatabaseResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; @@ -17,6 +21,7 @@ import com.tencent.supersonic.semantic.query.utils.QueryStructUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -31,6 +36,7 @@ import org.springframework.util.CollectionUtils; @Slf4j public class QueryReqConverter { + public static final String TABLE_PREFIX = "t_"; @Autowired private ModelService domainService; @Autowired @@ -41,38 +47,36 @@ public class QueryReqConverter { @Autowired private Catalog catalog; - public QueryStatement convert(QueryDslReq databaseReq, List domainSchemas) throws Exception { + public QueryStatement convert(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) throws Exception { List tables = new ArrayList<>(); MetricTable metricTable = new MetricTable(); - List allFields = SqlParserSelectHelper.getAllFields(databaseReq.getSql()); + if (Objects.isNull(modelSchemaResp)) { + return new QueryStatement(); + } + //1.convert name to bizName + convertNameToBizName(databaseReq, modelSchemaResp); + //2.functionName corrector + functionNameCorrector(databaseReq); + //3.correct tableName + correctTableName(databaseReq); + String tableName = SqlParserSelectHelper.getTableName(databaseReq.getSql()); - functionNameCorrector(databaseReq); - - if (CollectionUtils.isEmpty(domainSchemas) || StringUtils.isEmpty(tableName)) { + if (StringUtils.isEmpty(tableName)) { return new QueryStatement(); } - Set dimensions = domainSchemas.get(0).getDimensions().stream() - .map(entry -> entry.getBizName().toLowerCase()) - .collect(Collectors.toSet()); - dimensions.addAll(QueryStructUtils.internalCols); + List allFields = SqlParserSelectHelper.getAllFields(databaseReq.getSql()); - Set metrics = domainSchemas.get(0).getMetrics().stream().map(entry -> entry.getBizName().toLowerCase()) - .collect(Collectors.toSet()); + List metrics = getMetrics(modelSchemaResp, allFields); + metricTable.setMetrics(metrics); + + Set dimensions = getDimensions(modelSchemaResp, allFields); + + metricTable.setDimensions(new ArrayList<>(dimensions)); - metricTable.setMetrics(allFields.stream().filter(entry -> metrics.contains(entry.toLowerCase())) - .map(String::toLowerCase).collect(Collectors.toList())); - Set collect = allFields.stream().filter(entry -> dimensions.contains(entry.toLowerCase())) - .map(String::toLowerCase).collect(Collectors.toSet()); - for (String internalCol : QueryStructUtils.internalCols) { - if (databaseReq.getSql().contains(internalCol)) { - collect.add(internalCol); - } - } - metricTable.setDimensions(new ArrayList<>(collect)); metricTable.setAlias(tableName.toLowerCase()); // if metric empty , fill model default if (CollectionUtils.isEmpty(metricTable.getMetrics())) { @@ -92,6 +96,33 @@ public class QueryReqConverter { return queryStatement; } + private void convertNameToBizName(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) { + Map fieldNameToBizNameMap = getFieldNameToBizNameMap(modelSchemaResp); + String sql = databaseReq.getSql(); + log.info("convert name to bizName before:{}", sql); + String replaceFields = SqlParserUpdateHelper.replaceFields(sql, fieldNameToBizNameMap, true); + log.info("convert name to bizName after:{}", replaceFields); + databaseReq.setSql(replaceFields); + } + + private Set getDimensions(ModelSchemaResp modelSchemaResp, List allFields) { + Set allDimensions = modelSchemaResp.getDimensions().stream() + .map(entry -> entry.getBizName().toLowerCase()) + .collect(Collectors.toSet()); + allDimensions.addAll(QueryStructUtils.internalCols); + Set collect = allFields.stream().filter(entry -> allDimensions.contains(entry.toLowerCase())) + .map(String::toLowerCase).collect(Collectors.toSet()); + return collect; + } + + private List getMetrics(ModelSchemaResp modelSchemaResp, List allFields) { + Set allMetrics = modelSchemaResp.getMetrics().stream().map(entry -> entry.getBizName().toLowerCase()) + .collect(Collectors.toSet()); + List metrics = allFields.stream().filter(entry -> allMetrics.contains(entry.toLowerCase())) + .map(String::toLowerCase).collect(Collectors.toList()); + return metrics; + } + private void functionNameCorrector(QueryDslReq databaseReq) { DatabaseResp database = catalog.getDatabaseByModelId(databaseReq.getModelId()); if (Objects.isNull(database) || Objects.isNull(database.getType())) { @@ -107,4 +138,21 @@ public class QueryReqConverter { } } + + protected Map getFieldNameToBizNameMap(ModelSchemaResp modelSchemaResp) { + List allSchemaItems = new ArrayList<>(); + allSchemaItems.addAll(modelSchemaResp.getDimensions()); + allSchemaItems.addAll(modelSchemaResp.getMetrics()); + + Map result = allSchemaItems.stream() + .collect(Collectors.toMap(SchemaItem::getName, a -> a.getBizName(), (k1, k2) -> k1)); + result.put(DateUtils.DATE_FIELD, TimeDimensionEnum.DAY.getName()); + return result; + } + + public void correctTableName(QueryDslReq databaseReq) { + String sql = SqlParserUpdateHelper.replaceTable(databaseReq.getSql(), TABLE_PREFIX + databaseReq.getModelId()); + databaseReq.setSql(sql); + } + } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java index d39287f61..ea6c675a9 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java @@ -90,8 +90,11 @@ public class QueryServiceImpl implements QueryService { filter.setModelIds(modelIds); SchemaService schemaService = ContextUtils.getBean(SchemaService.class); List domainSchemas = schemaService.fetchModelSchema(filter, user); - - QueryStatement queryStatement = queryReqConverter.convert(querySqlCmd, domainSchemas); + ModelSchemaResp domainSchema = null; + if (CollectionUtils.isNotEmpty(domainSchemas)) { + domainSchema = domainSchemas.get(0); + } + QueryStatement queryStatement = queryReqConverter.convert(querySqlCmd, domainSchema); queryStatement.setModelId(querySqlCmd.getModelId()); return queryStatement; }