(improvement)(chat) logic sql show in chinese and convert to bizName in execute (#156)

This commit is contained in:
lexluo09
2023-09-27 17:27:31 +08:00
committed by GitHub
parent f931951ad5
commit 617db611c3
17 changed files with 138 additions and 100 deletions

View File

@@ -6,10 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
@@ -23,14 +23,11 @@ import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector { public abstract class BaseSemanticCorrector implements SemanticCorrector {
public static final String DATE_FIELD = "数据日期";
public void correct(SemanticCorrectInfo semanticCorrectInfo) { public void correct(SemanticCorrectInfo semanticCorrectInfo) {
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql()); semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
} }
protected Map<String, String> getFieldToBizName(Long modelId) { protected Map<String, String> getFieldNameMap(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -40,8 +37,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
Map<String, String> result = dbAllFields.stream() Map<String, String> result = dbAllFields.stream()
.filter(entry -> entry.getModel().equals(modelId)) .filter(entry -> entry.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1)); .collect(Collectors.toMap(SchemaElement::getName, a -> a.getName(), (k1, k2) -> k1));
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName()); result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD);
return result; return result;
} }
@@ -55,9 +52,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
whereFields.removeAll(selectFields); whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName()); whereFields.remove(DateUtils.DATE_FIELD);
whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
semanticCorrectInfo.setSql(replaceFields); semanticCorrectInfo.setSql(replaceFields);
} }
@@ -75,7 +70,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name()); schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
} }
return schemaElement; return schemaElement;
}).collect(Collectors.toMap(a -> a.getBizName(), a -> a.getDefaultAgg(), (k1, k2) -> k1)); }).collect(Collectors.toMap(a -> a.getName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) { if (CollectionUtils.isEmpty(metricToAggregate)) {
return; return;

View File

@@ -27,7 +27,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
updateFieldNameByLinkingValue(semanticCorrectInfo); updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldNameByBizName(semanticCorrectInfo); correctFieldName(semanticCorrectInfo);
addAggregateToMetric(semanticCorrectInfo); addAggregateToMetric(semanticCorrectInfo);
} }
@@ -39,11 +39,11 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
semanticCorrectInfo.setSql(replaceAlias); semanticCorrectInfo.setSql(replaceAlias);
} }
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) { private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) {
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()); Map<String, String> fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId());
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName); String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap);
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
} }

View File

@@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -28,8 +28,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream() Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
.filter(schemaElement -> !TimeDimensionEnum.DAY.getName().equals(schemaElement.getBizName())) .filter(schemaElement -> !DateUtils.DATE_FIELD.equals(schemaElement.getBizName()))
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet()); .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql); List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);

View File

@@ -25,7 +25,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(modelId).stream() Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
.map(schemaElement -> schemaElement.getBizName()).collect(Collectors.toSet()); .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) { if (CollectionUtils.isEmpty(metrics)) {
return; return;

View File

@@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -70,9 +70,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) { private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql(); String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql); List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) { if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate); sql = SqlParserUpdateHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
} }
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
} }
@@ -83,7 +83,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
} }
return queryFilters.getFilters().stream() return queryFilters.getFilters().stream()
.map(filter -> { .map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap; return bizNameWrap + operatorWrap + valueWrap;
@@ -117,11 +117,11 @@ public class WhereCorrector extends BaseSemanticCorrector {
for (SchemaElement dimension : dimensions) { for (SchemaElement dimension : dimensions) {
if (Objects.isNull(dimension) if (Objects.isNull(dimension)
|| Strings.isEmpty(dimension.getBizName()) || Strings.isEmpty(dimension.getName())
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) { || CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
continue; continue;
} }
String bizName = dimension.getBizName(); String name = dimension.getName();
Map<String, String> aliasAndBizNameToTechName = new HashMap<>(); Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
@@ -141,7 +141,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
} }
} }
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) { if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
result.put(bizName, aliasAndBizNameToTechName); result.put(name, aliasAndBizNameToTechName);
} }
} }
return result; return result;

View File

@@ -16,7 +16,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig; import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector;
import com.tencent.supersonic.chat.parser.SatisfactionChecker; import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver; import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
@@ -31,11 +30,11 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@@ -113,7 +112,7 @@ public class LLMDslParser implements SemanticParser {
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) { private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream() return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()) .filter(schemaElement -> modelId.equals(schemaElement.getModel())
&& allFields.contains(schemaElement.getBizName()) && allFields.contains(schemaElement.getName())
).collect(Collectors.toSet()); ).collect(Collectors.toSet());
} }
@@ -122,7 +121,7 @@ public class LLMDslParser implements SemanticParser {
return new ArrayList<>(); return new ArrayList<>();
} }
return allFields.stream() return allFields.stream()
.filter(entry -> !TimeDimensionEnum.getNameList().contains(entry)) .filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@@ -130,6 +129,7 @@ public class LLMDslParser implements SemanticParser {
String correctorSql = semanticCorrectInfo.getSql(); String correctorSql = semanticCorrectInfo.getSql();
parseInfo.getSqlInfo().setLogicSql(correctorSql); parseInfo.getSqlInfo().setLogicSql(correctorSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
//set dataInfo //set dataInfo
try { try {
@@ -143,8 +143,8 @@ public class LLMDslParser implements SemanticParser {
//set filter //set filter
try { try {
Map<String, SchemaElement> bizNameToElement = getBizNameToElement(modelId); Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(bizNameToElement, expressions); List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result); parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) { } catch (Exception e) {
log.error("set dimensionFilter error :", e); log.error("set dimensionFilter error :", e);
@@ -173,20 +173,18 @@ public class LLMDslParser implements SemanticParser {
} }
} }
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement, private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FilterExpression> filterExpressions) { List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList(); List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) { for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter(); QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue()); dimensionFilter.setValue(expression.getFieldValue());
String bizName = expression.getFieldName(); SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
SchemaElement schemaElement = bizNameToElement.get(bizName);
if (Objects.isNull(schemaElement)) { if (Objects.isNull(schemaElement)) {
continue; continue;
} }
String fieldName = schemaElement.getName(); dimensionFilter.setName(schemaElement.getName());
dimensionFilter.setName(fieldName); dimensionFilter.setBizName(schemaElement.getBizName());
dimensionFilter.setBizName(bizName);
dimensionFilter.setElementID(schemaElement.getId()); dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
@@ -198,13 +196,8 @@ public class LLMDslParser implements SemanticParser {
private DateConf getDateInfo(List<FilterExpression> filterExpressions) { private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream() List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> { .filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName()))
List<String> nameList = TimeDimensionEnum.getNameList(); .collect(Collectors.toList());
if (StringUtils.isEmpty(expression.getFieldName())) {
return false;
}
return nameList.contains(expression.getFieldName().toLowerCase());
}).collect(Collectors.toList());
if (CollectionUtils.isEmpty(dateExpressions)) { if (CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf(); return new DateConf();
} }
@@ -354,7 +347,7 @@ public class LLMDslParser implements SemanticParser {
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig); List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD); fieldNameList.add(DateUtils.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList); llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema); llmReq.setSchema(llmSchema);
@@ -391,7 +384,7 @@ public class LLMDslParser implements SemanticParser {
} }
protected Map<String, SchemaElement> getBizNameToElement(Long modelId) { protected Map<String, SchemaElement> getNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions(); List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics(); List<SchemaElement> metrics = semanticSchema.getMetrics();
@@ -401,7 +394,7 @@ public class LLMDslParser implements SemanticParser {
allElements.addAll(metrics); allElements.addAll(metrics);
return allElements.stream() return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId)) .filter(schemaElement -> schemaElement.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2)); .collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
} }

View File

@@ -68,7 +68,7 @@ class LLMDslParserTest {
model.setId(2L); model.setId(2L);
parseInfo.setModel(model); parseInfo.setModel(model);
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ") .sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 ")
.parseInfo(parseInfo) .parseInfo(parseInfo)
.build(); .build();

View File

@@ -16,6 +16,7 @@ public class DateUtils {
public static final String DATE_FORMAT = "yyyy-MM-dd"; public static final String DATE_FORMAT = "yyyy-MM-dd";
public static final String DATE_FIELD = "数据日期";
public static final String TIME_FORMAT = "yyyy-MM-dd HH:mm:ss"; public static final String TIME_FORMAT = "yyyy-MM-dd HH:mm:ss";
public static Integer currentYear() { public static Integer currentYear() {

View File

@@ -9,16 +9,16 @@ import net.sf.jsqlparser.schema.Column;
public class FieldReplaceVisitor extends ExpressionVisitorAdapter { public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName; private Map<String, String> fieldNameMap;
private boolean exactReplace; private boolean exactReplace;
public FieldReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) { public FieldReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldToBizName = fieldToBizName; this.fieldNameMap = fieldNameMap;
this.exactReplace = exactReplace; this.exactReplace = exactReplace;
} }
@Override @Override
public void visit(Column column) { public void visit(Column column) {
parseVisitorHelper.replaceColumn(column, fieldToBizName, exactReplace); parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace);
} }
} }

View File

@@ -17,11 +17,11 @@ import org.apache.commons.lang3.StringUtils;
public class GroupByReplaceVisitor implements GroupByVisitor { public class GroupByReplaceVisitor implements GroupByVisitor {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName; private Map<String, String> fieldNameMap;
private boolean exactReplace; private boolean exactReplace;
public GroupByReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) { public GroupByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldToBizName = fieldToBizName; this.fieldNameMap = fieldNameMap;
this.exactReplace = exactReplace; this.exactReplace = exactReplace;
} }
@@ -33,7 +33,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
for (int i = 0; i < groupByExpressions.size(); i++) { for (int i = 0; i < groupByExpressions.size(); i++) {
Expression expression = groupByExpressions.get(i); Expression expression = groupByExpressions.get(i);
String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName, String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldNameMap,
exactReplace); exactReplace);
if (StringUtils.isNotEmpty(replaceColumn)) { if (StringUtils.isNotEmpty(replaceColumn)) {
if (expression instanceof Column) { if (expression instanceof Column) {

View File

@@ -11,11 +11,11 @@ import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter;
public class OrderByReplaceVisitor extends OrderByVisitorAdapter { public class OrderByReplaceVisitor extends OrderByVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName; private Map<String, String> fieldNameMap;
private boolean exactReplace; private boolean exactReplace;
public OrderByReplaceVisitor(Map<String, String> fieldToBizName, boolean exactReplace) { public OrderByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldToBizName = fieldToBizName; this.fieldNameMap = fieldNameMap;
this.exactReplace = exactReplace; this.exactReplace = exactReplace;
} }
@@ -23,14 +23,14 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter {
public void visit(OrderByElement orderBy) { public void visit(OrderByElement orderBy) {
Expression expression = orderBy.getExpression(); Expression expression = orderBy.getExpression();
if (expression instanceof Column) { if (expression instanceof Column) {
parseVisitorHelper.replaceColumn((Column) expression, fieldToBizName, exactReplace); parseVisitorHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace);
} }
if (expression instanceof Function) { if (expression instanceof Function) {
Function function = (Function) expression; Function function = (Function) expression;
List<Expression> expressions = function.getParameters().getExpressions(); List<Expression> expressions = function.getParameters().getExpressions();
for (Expression column : expressions) { for (Expression column : expressions) {
if (column instanceof Column) { if (column instanceof Column) {
parseVisitorHelper.replaceColumn((Column) column, fieldToBizName, exactReplace); parseVisitorHelper.replaceColumn((Column) column, fieldNameMap, exactReplace);
} }
} }
} }

View File

@@ -11,23 +11,23 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j @Slf4j
public class ParseVisitorHelper { public class ParseVisitorHelper {
public void replaceColumn(Column column, Map<String, String> fieldToBizName, boolean exactReplace) { public void replaceColumn(Column column, Map<String, String> fieldNameMap, boolean exactReplace) {
String columnName = column.getColumnName(); String columnName = column.getColumnName();
String replaceColumn = getReplaceColumn(columnName, fieldToBizName, exactReplace); String replaceColumn = getReplaceColumn(columnName, fieldNameMap, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) { if (StringUtils.isNotBlank(replaceColumn)) {
column.setColumnName(replaceColumn); column.setColumnName(replaceColumn);
} }
} }
public String getReplaceColumn(String columnName, Map<String, String> fieldToBizName, boolean exactReplace) { public String getReplaceColumn(String columnName, Map<String, String> fieldNameMap, boolean exactReplace) {
String fieldBizName = fieldToBizName.get(columnName); String fieldName = fieldNameMap.get(columnName);
if (StringUtils.isNotBlank(fieldBizName)) { if (StringUtils.isNotBlank(fieldName)) {
return fieldBizName; return fieldName;
} }
if (exactReplace) { if (exactReplace) {
return null; return null;
} }
Optional<Entry<String, String>> first = fieldToBizName.entrySet().stream().sorted((k1, k2) -> { Optional<Entry<String, String>> first = fieldNameMap.entrySet().stream().sorted((k1, k2) -> {
String k1FieldNameDb = k1.getKey(); String k1FieldNameDb = k1.getKey();
String k2FieldNameDb = k2.getKey(); String k2FieldNameDb = k2.getKey();
Double k1Similarity = getSimilarity(columnName, k1FieldNameDb); Double k1Similarity = getSimilarity(columnName, k1FieldNameDb);

View File

@@ -65,11 +65,11 @@ public class SqlParserUpdateHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
public static String replaceFields(String sql, Map<String, String> fieldToBizName) { public static String replaceFields(String sql, Map<String, String> fieldNameMap) {
return replaceFields(sql, fieldToBizName, false); return replaceFields(sql, fieldNameMap, false);
} }
public static String replaceFields(String sql, Map<String, String> fieldToBizName, boolean exactReplace) { public static String replaceFields(String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql); Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody(); SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) { if (!(selectBody instanceof PlainSelect)) {
@@ -78,7 +78,7 @@ public class SqlParserUpdateHelper {
PlainSelect plainSelect = (PlainSelect) selectBody; PlainSelect plainSelect = (PlainSelect) selectBody;
//1. replace where fields //1. replace where fields
Expression where = plainSelect.getWhere(); Expression where = plainSelect.getWhere();
FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldToBizName, exactReplace); FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace);
if (Objects.nonNull(where)) { if (Objects.nonNull(where)) {
where.accept(visitor); where.accept(visitor);
} }
@@ -92,14 +92,14 @@ public class SqlParserUpdateHelper {
List<OrderByElement> orderByElements = plainSelect.getOrderByElements(); List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
if (!CollectionUtils.isEmpty(orderByElements)) { if (!CollectionUtils.isEmpty(orderByElements)) {
for (OrderByElement orderByElement : orderByElements) { for (OrderByElement orderByElement : orderByElements) {
orderByElement.accept(new OrderByReplaceVisitor(fieldToBizName, exactReplace)); orderByElement.accept(new OrderByReplaceVisitor(fieldNameMap, exactReplace));
} }
} }
//4. replace group by fields //4. replace group by fields
GroupByElement groupByElement = plainSelect.getGroupBy(); GroupByElement groupByElement = plainSelect.getGroupBy();
if (Objects.nonNull(groupByElement)) { if (Objects.nonNull(groupByElement)) {
groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName, exactReplace)); groupByElement.accept(new GroupByReplaceVisitor(fieldNameMap, exactReplace));
} }
//5. replace having fields //5. replace having fields
Expression having = plainSelect.getHaving(); Expression having = plainSelect.getHaving();

View File

@@ -36,5 +36,4 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector com.tencent.supersonic.chat.corrector.GlobalAfterCorrector

View File

@@ -36,5 +36,4 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector com.tencent.supersonic.chat.corrector.GlobalAfterCorrector

View File

@@ -1,6 +1,10 @@
package com.tencent.supersonic.semantic.query.parser.convert; package com.tencent.supersonic.semantic.query.parser.convert;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.request.SqlExecuteReq; import com.tencent.supersonic.semantic.api.model.request.SqlExecuteReq;
import com.tencent.supersonic.semantic.api.model.response.DatabaseResp; import com.tencent.supersonic.semantic.api.model.response.DatabaseResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
@@ -17,6 +21,7 @@ import com.tencent.supersonic.semantic.query.utils.QueryStructUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -31,6 +36,7 @@ import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class QueryReqConverter { public class QueryReqConverter {
public static final String TABLE_PREFIX = "t_";
@Autowired @Autowired
private ModelService domainService; private ModelService domainService;
@Autowired @Autowired
@@ -41,38 +47,36 @@ public class QueryReqConverter {
@Autowired @Autowired
private Catalog catalog; private Catalog catalog;
public QueryStatement convert(QueryDslReq databaseReq, List<ModelSchemaResp> domainSchemas) throws Exception { public QueryStatement convert(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) throws Exception {
List<MetricTable> tables = new ArrayList<>(); List<MetricTable> tables = new ArrayList<>();
MetricTable metricTable = new MetricTable(); MetricTable metricTable = new MetricTable();
List<String> allFields = SqlParserSelectHelper.getAllFields(databaseReq.getSql()); if (Objects.isNull(modelSchemaResp)) {
return new QueryStatement();
}
//1.convert name to bizName
convertNameToBizName(databaseReq, modelSchemaResp);
//2.functionName corrector
functionNameCorrector(databaseReq);
//3.correct tableName
correctTableName(databaseReq);
String tableName = SqlParserSelectHelper.getTableName(databaseReq.getSql()); String tableName = SqlParserSelectHelper.getTableName(databaseReq.getSql());
functionNameCorrector(databaseReq); if (StringUtils.isEmpty(tableName)) {
if (CollectionUtils.isEmpty(domainSchemas) || StringUtils.isEmpty(tableName)) {
return new QueryStatement(); return new QueryStatement();
} }
Set<String> dimensions = domainSchemas.get(0).getDimensions().stream() List<String> allFields = SqlParserSelectHelper.getAllFields(databaseReq.getSql());
.map(entry -> entry.getBizName().toLowerCase())
.collect(Collectors.toSet());
dimensions.addAll(QueryStructUtils.internalCols);
Set<String> metrics = domainSchemas.get(0).getMetrics().stream().map(entry -> entry.getBizName().toLowerCase()) List<String> metrics = getMetrics(modelSchemaResp, allFields);
.collect(Collectors.toSet()); metricTable.setMetrics(metrics);
Set<String> dimensions = getDimensions(modelSchemaResp, allFields);
metricTable.setDimensions(new ArrayList<>(dimensions));
metricTable.setMetrics(allFields.stream().filter(entry -> metrics.contains(entry.toLowerCase()))
.map(String::toLowerCase).collect(Collectors.toList()));
Set<String> collect = allFields.stream().filter(entry -> dimensions.contains(entry.toLowerCase()))
.map(String::toLowerCase).collect(Collectors.toSet());
for (String internalCol : QueryStructUtils.internalCols) {
if (databaseReq.getSql().contains(internalCol)) {
collect.add(internalCol);
}
}
metricTable.setDimensions(new ArrayList<>(collect));
metricTable.setAlias(tableName.toLowerCase()); metricTable.setAlias(tableName.toLowerCase());
// if metric empty , fill model default // if metric empty , fill model default
if (CollectionUtils.isEmpty(metricTable.getMetrics())) { if (CollectionUtils.isEmpty(metricTable.getMetrics())) {
@@ -92,6 +96,33 @@ public class QueryReqConverter {
return queryStatement; return queryStatement;
} }
private void convertNameToBizName(QueryDslReq databaseReq, ModelSchemaResp modelSchemaResp) {
Map<String, String> fieldNameToBizNameMap = getFieldNameToBizNameMap(modelSchemaResp);
String sql = databaseReq.getSql();
log.info("convert name to bizName before:{}", sql);
String replaceFields = SqlParserUpdateHelper.replaceFields(sql, fieldNameToBizNameMap, true);
log.info("convert name to bizName after:{}", replaceFields);
databaseReq.setSql(replaceFields);
}
private Set<String> getDimensions(ModelSchemaResp modelSchemaResp, List<String> allFields) {
Set<String> allDimensions = modelSchemaResp.getDimensions().stream()
.map(entry -> entry.getBizName().toLowerCase())
.collect(Collectors.toSet());
allDimensions.addAll(QueryStructUtils.internalCols);
Set<String> collect = allFields.stream().filter(entry -> allDimensions.contains(entry.toLowerCase()))
.map(String::toLowerCase).collect(Collectors.toSet());
return collect;
}
private List<String> getMetrics(ModelSchemaResp modelSchemaResp, List<String> allFields) {
Set<String> allMetrics = modelSchemaResp.getMetrics().stream().map(entry -> entry.getBizName().toLowerCase())
.collect(Collectors.toSet());
List<String> metrics = allFields.stream().filter(entry -> allMetrics.contains(entry.toLowerCase()))
.map(String::toLowerCase).collect(Collectors.toList());
return metrics;
}
private void functionNameCorrector(QueryDslReq databaseReq) { private void functionNameCorrector(QueryDslReq databaseReq) {
DatabaseResp database = catalog.getDatabaseByModelId(databaseReq.getModelId()); DatabaseResp database = catalog.getDatabaseByModelId(databaseReq.getModelId());
if (Objects.isNull(database) || Objects.isNull(database.getType())) { if (Objects.isNull(database) || Objects.isNull(database.getType())) {
@@ -107,4 +138,21 @@ public class QueryReqConverter {
} }
} }
protected Map<String, String> getFieldNameToBizNameMap(ModelSchemaResp modelSchemaResp) {
List<SchemaItem> allSchemaItems = new ArrayList<>();
allSchemaItems.addAll(modelSchemaResp.getDimensions());
allSchemaItems.addAll(modelSchemaResp.getMetrics());
Map<String, String> result = allSchemaItems.stream()
.collect(Collectors.toMap(SchemaItem::getName, a -> a.getBizName(), (k1, k2) -> k1));
result.put(DateUtils.DATE_FIELD, TimeDimensionEnum.DAY.getName());
return result;
}
public void correctTableName(QueryDslReq databaseReq) {
String sql = SqlParserUpdateHelper.replaceTable(databaseReq.getSql(), TABLE_PREFIX + databaseReq.getModelId());
databaseReq.setSql(sql);
}
} }

View File

@@ -90,8 +90,11 @@ public class QueryServiceImpl implements QueryService {
filter.setModelIds(modelIds); filter.setModelIds(modelIds);
SchemaService schemaService = ContextUtils.getBean(SchemaService.class); SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
List<ModelSchemaResp> domainSchemas = schemaService.fetchModelSchema(filter, user); List<ModelSchemaResp> domainSchemas = schemaService.fetchModelSchema(filter, user);
ModelSchemaResp domainSchema = null;
QueryStatement queryStatement = queryReqConverter.convert(querySqlCmd, domainSchemas); if (CollectionUtils.isNotEmpty(domainSchemas)) {
domainSchema = domainSchemas.get(0);
}
QueryStatement queryStatement = queryReqConverter.convert(querySqlCmd, domainSchema);
queryStatement.setModelId(querySqlCmd.getModelId()); queryStatement.setModelId(querySqlCmd.getModelId());
return queryStatement; return queryStatement;
} }