From ae9aa1ba0f49cab41bc0bda27d92d04521d32bb5 Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Tue, 31 Oct 2023 15:54:15 +0800 Subject: [PATCH] (improvement)(chat) fix in replace bug (#302) --- .../chat/corrector/GlobalBeforeCorrector.java | 2 +- .../chat/service/impl/QueryServiceImpl.java | 69 ++++++++++++----- .../jsqlparser/FieldlValueReplaceVisitor.java | 9 +++ .../jsqlparser/SqlParserSelectHelperTest.java | 5 ++ .../semantic/query/utils/DimValueAspect.java | 74 +++++++++++-------- 5 files changed, 111 insertions(+), 48 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java index 9b4934e59..ed56240f3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java @@ -93,4 +93,4 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false); semanticCorrectInfo.setSql(sql); } -} \ No newline at end of file +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 1c7d8d006..44a437a3c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.service.impl; import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SchemaMapper; +import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.ChatContext; @@ -34,6 +35,7 @@ import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.StatisticsService; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.SolvedQueryManager; +import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.util.ContextUtils; @@ -62,6 +64,8 @@ import java.util.Objects; import java.util.PriorityQueue; import java.util.Set; import java.util.stream.Collectors; + +import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.LongValue; @@ -247,7 +251,7 @@ public class QueryServiceImpl implements QueryService { queryResult.setQueryTimeCost(System.currentTimeMillis() - executeTime); return queryResult; } - + // save time cost data public void saveInfo(List timeCostDOList, String queryText, Long queryId, String userName, Long chatId) { @@ -324,7 +328,8 @@ public class QueryServiceImpl implements QueryService { ChatContext context = chatService.getOrCreateContext(queryCtx.getChatId()); return context.getParseInfo(); } - + //mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000", + //"style='流行'"->"style in ['流行','爱国']" @Override public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException { ChatParseDO chatParseDO = chatService.getParseInfo(queryData.getQueryId(), @@ -339,19 +344,21 @@ public class QueryServiceImpl implements QueryService { String correctorSql = parseInfo.getSqlInfo().getLogicSql(); log.info("correctorSql before replacing:{}", correctorSql); + // get where filter and having filter List whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql); List havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql); List addWhereConditions = new ArrayList<>(); List addHavingConditions = new ArrayList<>(); Set removeWhereFieldNames = new HashSet<>(); Set removeHavingFieldNames = new HashSet<>(); + // replace where filter updateFilters(filedNameToValueMap, whereExpressionList, queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); updateDateInfo(queryData, parseInfo, filedNameToValueMap, whereExpressionList, addWhereConditions, removeWhereFieldNames); correctorSql = SqlParserReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); correctorSql = SqlParserRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); - + // replace having filter updateFilters(havingFiledNameToValueMap, havingExpressionList, queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); @@ -399,15 +406,18 @@ public class QueryServiceImpl implements QueryService { queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1)); } + // startDate equals to endDate if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) { for (FilterExpression filterExpression : filterExpressionList) { if (DateUtils.DATE_FIELD.equals(filterExpression.getFieldName())) { + //sql where condition exists 'equals' operator about date,just replace if (filterExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) { dateField = filterExpression.getFieldName(); map.put(filterExpression.getFieldValue().toString(), queryData.getDateInfo().getStartDate()); filedNameToValueMap.put(dateField, map); } else { + // first remove,then add removeFieldNames.add(DateUtils.DATE_FIELD); EqualsTo equalsTo = new EqualsTo(); Column column = new Column(DateUtils.DATE_FIELD); @@ -423,6 +433,7 @@ public class QueryServiceImpl implements QueryService { for (FilterExpression filterExpression : filterExpressionList) { if (DateUtils.DATE_FIELD.equals(filterExpression.getFieldName())) { dateField = filterExpression.getFieldName(); + //just replace if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator()) || FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) { map.put(filterExpression.getFieldValue().toString(), @@ -434,12 +445,13 @@ public class QueryServiceImpl implements QueryService { queryData.getDateInfo().getEndDate()); } filedNameToValueMap.put(dateField, map); + // first remove,then add if (FilterOperatorEnum.EQUALS.getValue().equals(filterExpression.getOperator())) { removeFieldNames.add(DateUtils.DATE_FIELD); GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); - addTimeCondition(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions); + addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions); MinorThanEquals minorThanEquals = new MinorThanEquals(); - addTimeCondition(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions); + addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions); } } } @@ -447,7 +459,7 @@ public class QueryServiceImpl implements QueryService { parseInfo.setDateInfo(queryData.getDateInfo()); } - public void addTimeCondition(String date, + public void addTimeFilters(String date, T comparisonExpression, List addConditions) { Column column = new Column(DateUtils.DATE_FIELD); @@ -473,30 +485,30 @@ public class QueryServiceImpl implements QueryService { removeFieldNames.add(dslQueryFilter.getName()); if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) { EqualsTo equalsTo = new EqualsTo(); - addWhereCondition(dslQueryFilter, equalsTo, contextMetricFilters, addConditions); + addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions); } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) { GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); - addWhereCondition(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions); + addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions); } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) { GreaterThan greaterThan = new GreaterThan(); - addWhereCondition(dslQueryFilter, greaterThan, contextMetricFilters, addConditions); + addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions); } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) { MinorThanEquals minorThanEquals = new MinorThanEquals(); - addWhereCondition(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions); + addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions); } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) { MinorThan minorThan = new MinorThan(); - addWhereCondition(dslQueryFilter, minorThan, contextMetricFilters, addConditions); + addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions); } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) { InExpression inExpression = new InExpression(); - addWhereInCondition(dslQueryFilter, inExpression, contextMetricFilters, addConditions); + addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions); } break; } } } } - - public void addWhereInCondition(QueryFilter dslQueryFilter, + // add in condition to sql where condition + public void addWhereInFilters(QueryFilter dslQueryFilter, InExpression inExpression, Set contextMetricFilters, List addConditions) { @@ -523,8 +535,8 @@ public class QueryServiceImpl implements QueryService { } }); } - - public void addWhereCondition(QueryFilter dslQueryFilter, + // add where filter + public void addWhereFilters(QueryFilter dslQueryFilter, T comparisonExpression, Set contextMetricFilters, List addConditions) { @@ -580,7 +592,11 @@ public class QueryServiceImpl implements QueryService { Set detectModelIds = new HashSet<>(); detectModelIds.add(dimensionValueReq.getModelId()); List dimensionValues = getDimensionValues(dimensionValueReq, detectModelIds); - + // if the search results is null,search dimensionValue from database + if (CollectionUtils.isEmpty(dimensionValues)) { + queryResultWithSchemaResp = queryDatabase(dimensionValueReq, user); + return queryResultWithSchemaResp; + } List columns = new ArrayList<>(); QueryColumn queryColumn = new QueryColumn(); queryColumn.setNameEn(dimensionValueReq.getBizName()); @@ -628,5 +644,24 @@ public class QueryServiceImpl implements QueryService { .collect(Collectors.toList()); } + private QueryResultWithSchemaResp queryDatabase(DimensionValueReq dimensionValueReq, User user) { + QueryStructReq queryStructReq = new QueryStructReq(); + + DateConf dateConf = new DateConf(); + dateConf.setDateMode(DateConf.DateMode.RECENT); + dateConf.setUnit(1); + dateConf.setPeriod("DAY"); + queryStructReq.setDateInfo(dateConf); + queryStructReq.setLimit(20L); + queryStructReq.setModelId(dimensionValueReq.getModelId()); + queryStructReq.setNativeQuery(false); + List groups = new ArrayList<>(); + groups.add(dimensionValueReq.getBizName()); + queryStructReq.setGroups(groups); + SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); + QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user); + return queryResultWithSchemaResp; + } + } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java index d9e4b2eec..c333bcdca 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java @@ -59,6 +59,9 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { } public void visit(InExpression inExpression) { + if (!(inExpression.getLeftExpression() instanceof Column)) { + return; + } Column column = (Column) inExpression.getLeftExpression(); Map valueMap = filedNameToValueMap.get(column.getColumnName()); ExpressionList rightItemsList = (ExpressionList) inExpression.getRightItemsList(); @@ -69,7 +72,13 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { values.add(((StringValue) o).getValue()); } }); + if (valueMap == null) { + return; + } String value = valueMap.get(JsonUtil.toString(values)); + if (StringUtils.isBlank(value)) { + return; + } List valueList = JsonUtil.toList(value, String.class); List newExpressions = new ArrayList<>(); valueList.stream().forEach(o -> { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index 6ebb9d555..f45ac2cb8 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.common.util.jsqlparser; import java.util.List; import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.statement.select.Select; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -13,6 +14,10 @@ class SqlParserSelectHelperTest { @Test void getWhereFilterExpression() { + Select selectStatement = SqlParserSelectHelper.getSelect( + "select 用户名, 访问次数 from 超音数 where 用户名 in ('alice', 'lucy')"); + System.out.println(selectStatement); + List filterExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE " + "sys_imp_date = '2023-08-08' AND YEAR(publish_date) = 2023 " diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java index 418df7b60..98577189f 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java @@ -56,6 +56,7 @@ public class DimValueAspect { QueryS2QLReq queryS2QLReq = (QueryS2QLReq) args[0]; String sql = queryS2QLReq.getSql(); log.info("correctorSql before replacing:{}", sql); + // if dimensionvalue is alias,consider the true dimensionvalue. List filterExpressionList = SqlParserSelectHelper.getWhereExpressions(sql); List dimensions = dimensionService.getDimensions(queryS2QLReq.getModelId()); Set fieldNames = dimensions.stream().map(o -> o.getName()).collect(Collectors.toSet()); @@ -63,42 +64,20 @@ public class DimValueAspect { filterExpressionList.stream().forEach(expression -> { if (fieldNames.contains(expression.getFieldName())) { dimensions.stream().forEach(dimension -> { - if (expression.getFieldName().equals(dimension.getName())) { - if (expression.getOperator().equals(FilterOperatorEnum.EQUALS.getValue()) - && !CollectionUtils.isEmpty(dimension.getDimValueMaps())) { + if (expression.getFieldName().equals(dimension.getName()) + && !CollectionUtils.isEmpty(dimension.getDimValueMaps())) { + // consider '=' filter + if (expression.getOperator().equals(FilterOperatorEnum.EQUALS.getValue())) { dimension.getDimValueMaps().stream().forEach(dimValue -> { if (!CollectionUtils.isEmpty(dimValue.getAlias()) && dimValue.getAlias().contains(expression.getFieldValue().toString())) { - Map map = new HashMap<>(); - map.put(expression.getFieldValue().toString(), dimValue.getTechName()); - filedNameToValueMap.put(expression.getFieldName(), map); + getFiledNameToValueMap(filedNameToValueMap, expression.getFieldValue().toString(), + dimValue.getTechName(), expression.getFieldName()); } }); } - if (expression.getOperator().equals(FilterOperatorEnum.IN.getValue())) { - String fieldValue = JsonUtil.toString(expression.getFieldValue()); - fieldValue = fieldValue.replace("'", ""); - List values = JsonUtil.toList(fieldValue, String.class); - List revisedValues = new ArrayList<>(); - for (int i = 0; i < values.size(); i++) { - Boolean flag = new Boolean(false); - for (DimValueMap dimValueMap : dimension.getDimValueMaps()) { - if (dimValueMap.getAlias().contains(values.get(i))) { - flag = true; - revisedValues.add(dimValueMap.getTechName()); - break; - } - } - if (!flag) { - revisedValues.add(values.get(i)); - } - } - if (!revisedValues.equals(values)) { - Map map = new HashMap<>(); - map.put(JsonUtil.toString(values), JsonUtil.toString(revisedValues)); - filedNameToValueMap.put(expression.getFieldName(), map); - } - } + // consider 'in' filter,each element needs to judge. + replaceInCondition(expression, dimension, filedNameToValueMap); } }); } @@ -116,6 +95,41 @@ public class DimValueAspect { return queryResultWithColumns; } + public void replaceInCondition(FilterExpression expression, DimensionResp dimension, + Map> filedNameToValueMap) { + if (expression.getOperator().equals(FilterOperatorEnum.IN.getValue())) { + String fieldValue = JsonUtil.toString(expression.getFieldValue()); + fieldValue = fieldValue.replace("'", ""); + List values = JsonUtil.toList(fieldValue, String.class); + List revisedValues = new ArrayList<>(); + for (int i = 0; i < values.size(); i++) { + Boolean flag = new Boolean(false); + for (DimValueMap dimValueMap : dimension.getDimValueMaps()) { + if (!CollectionUtils.isEmpty(dimValueMap.getAlias()) + && dimValueMap.getAlias().contains(values.get(i))) { + flag = true; + revisedValues.add(dimValueMap.getTechName()); + break; + } + } + if (!flag) { + revisedValues.add(values.get(i)); + } + } + if (!revisedValues.equals(values)) { + getFiledNameToValueMap(filedNameToValueMap, JsonUtil.toString(values), + JsonUtil.toString(revisedValues), expression.getFieldName()); + } + } + } + + public void getFiledNameToValueMap(Map> filedNameToValueMap, + String oldValue, String newValue, String fieldName) { + Map map = new HashMap<>(); + map.put(oldValue, newValue); + filedNameToValueMap.put(fieldName, map); + } + @Around("execution(* com.tencent.supersonic.semantic.query.rest.QueryController.queryByStruct(..))" + " || execution(* com.tencent.supersonic.semantic.query.service.QueryService.queryByStruct(..))"