From 3d1ca6ac1da7cc28aeb12ca7c32992a2d7ae64c7 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:56:13 +0800 Subject: [PATCH] (improvement)(chat) Optimize the code for the queryData and queryDimensionValue interfaces. (#1529) --- .../service/impl/ChatQueryServiceImpl.java | 222 ++++++++++-------- ...tor.java => FieldValueReplaceVisitor.java} | 4 +- .../common/jsqlparser/SqlReplaceHelper.java | 23 +- .../common/pojo/enums/FilterOperatorEnum.java | 22 ++ .../service/impl/S2SemanticLayerService.java | 74 +++--- 5 files changed, 199 insertions(+), 146 deletions(-) rename common/src/main/java/com/tencent/supersonic/common/jsqlparser/{FieldlValueReplaceVisitor.java => FieldValueReplaceVisitor.java} (96%) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index a2d3cfb0b..656233a8b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq; +import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor; import com.tencent.supersonic.chat.server.parser.ChatQueryParser; @@ -22,7 +23,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; -import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.BeanMapper; @@ -33,13 +33,13 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; @@ -54,11 +54,8 @@ import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; -import net.sf.jsqlparser.expression.operators.relational.EqualsTo; -import net.sf.jsqlparser.expression.operators.relational.GreaterThan; import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.InExpression; -import net.sf.jsqlparser.expression.operators.relational.MinorThan; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList; import net.sf.jsqlparser.schema.Column; @@ -71,6 +68,7 @@ import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -195,32 +193,43 @@ public class ChatQueryServiceImpl implements ChatQueryService { return executeContext; } - //mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000", - //"style='流行'"->"style in ['流行','爱国']" @Override public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception { Integer parseId = chatQueryDataReq.getParseId(); - SemanticParseInfo parseInfo = chatManageService.getParseInfo( - chatQueryDataReq.getQueryId(), parseId); - parseInfo = mergeSemanticParseInfo(parseInfo, chatQueryDataReq); + SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId); + parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq); DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId()); SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); semanticQuery.setParseInfo(parseInfo); - List fields = new ArrayList<>(); - if (Objects.nonNull(parseInfo.getSqlInfo()) - && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { - String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); - fields = SqlSelectHelper.getAllSelectFields(correctorSql); + if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { + handleLLMQueryMode(chatQueryDataReq, semanticQuery, user); + } else { + handleRuleQueryMode(semanticQuery, dataSetSchema, user); } - if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode()) - && checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { - //replace metrics + + return executeQuery(semanticQuery, user, dataSetSchema); + } + + private List getFieldsFromSql(SemanticParseInfo parseInfo) { + SqlInfo sqlInfo = parseInfo.getSqlInfo(); + if (Objects.isNull(sqlInfo) || StringUtils.isNotBlank(sqlInfo.getCorrectedS2SQL())) { + return new ArrayList<>(); + } + return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL()); + } + + private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq, + SemanticQuery semanticQuery, + User user) throws Exception { + SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); + List fields = getFieldsFromSql(parseInfo); + if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { log.info("llm begin replace metrics!"); SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next(); replaceMetrics(parseInfo, metricToReplace); - } else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { + } else { log.info("llm begin revise filters!"); String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo); parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); @@ -228,16 +237,24 @@ public class ChatQueryServiceImpl implements ChatQueryService { SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user); parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); - } else { - log.info("rule begin replace metrics and revise filters!"); - //remove unvalid filters - validFilter(semanticQuery.getParseInfo().getDimensionFilters()); - validFilter(semanticQuery.getParseInfo().getMetricFilters()); - //init s2sql - semanticQuery.initS2Sql(dataSetSchema, user); } + } + + private void handleRuleQueryMode(SemanticQuery semanticQuery, + DataSetSchema dataSetSchema, + User user) { + log.info("rule begin replace metrics and revise filters!"); + validFilter(semanticQuery.getParseInfo().getDimensionFilters()); + validFilter(semanticQuery.getParseInfo().getMetricFilters()); + semanticQuery.initS2Sql(dataSetSchema, user); + } + + private QueryResult executeQuery(SemanticQuery semanticQuery, + User user, + DataSetSchema dataSetSchema) throws Exception { SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); - QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user); + SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); + QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user); queryResult.setChatContext(semanticQuery.getParseInfo()); SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user); @@ -246,10 +263,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private boolean checkMetricReplace(List oriFields, Set metrics) { - if (CollectionUtils.isEmpty(oriFields)) { - return false; - } - if (CollectionUtils.isEmpty(metrics)) { + if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) { return false; } List metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()); @@ -257,29 +271,30 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) { - Map> filedNameToValueMap = new HashMap<>(); - Map> havingFiledNameToValueMap = new HashMap<>(); - String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); log.info("correctorSql before replacing:{}", correctorSql); // get where filter and having filter List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql); - List havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql); - List addWhereConditions = new ArrayList<>(); - List addHavingConditions = new ArrayList<>(); - Set removeWhereFieldNames = new HashSet<>(); - Set removeHavingFieldNames = new HashSet<>(); + // replace where filter - updateFilters(whereExpressionList, queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames); - updateDateInfo(queryData, parseInfo, filedNameToValueMap, - whereExpressionList, addWhereConditions, removeWhereFieldNames); + List addWhereConditions = new ArrayList<>(); + Set removeWhereFieldNames = updateFilters(whereExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters(), addWhereConditions); + + Map> filedNameToValueMap = new HashMap<>(); + Set removeDataFieldNames = updateDateInfo(queryData, parseInfo, filedNameToValueMap, + whereExpressionList, addWhereConditions); + removeWhereFieldNames.addAll(removeDataFieldNames); + correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames); + // replace having filter - updateFilters(havingExpressionList, queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames); - correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); + List havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql); + List addHavingConditions = new ArrayList<>(); + Set removeHavingFieldNames = updateFilters(havingExpressionList, + queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions); + correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>()); correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions); @@ -303,34 +318,32 @@ public class ChatQueryServiceImpl implements ChatQueryService { parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); } - private QueryResult doExecution(SemanticQueryReq semanticQueryReq, - SemanticParseInfo parseInfo, User user) throws Exception { + private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user) throws Exception { SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user); QueryResult queryResult = new QueryResult(); + if (queryResp != null) { queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); + queryResult.setQuerySql(queryResp.getSql()); + queryResult.setQueryResults(queryResp.getResultList()); + queryResult.setQueryColumns(queryResp.getColumns()); + } else { + queryResult.setQueryResults(new ArrayList<>()); + queryResult.setQueryColumns(new ArrayList<>()); } - String sql = queryResp == null ? null : queryResp.getSql(); - List> resultList = queryResp == null ? new ArrayList<>() - : queryResp.getResultList(); - List columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns(); - queryResult.setQuerySql(sql); - queryResult.setQueryResults(resultList); - queryResult.setQueryColumns(columns); - queryResult.setQueryMode(parseInfo.getQueryMode()); + queryResult.setQueryMode(queryMode); queryResult.setQueryState(QueryState.SUCCESS); - return queryResult; } - private void updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, - Map> filedNameToValueMap, - List fieldExpressionList, - List addConditions, - Set removeFieldNames) { + private Set updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, + Map> filedNameToValueMap, + List fieldExpressionList, + List addConditions) { + Set removeFieldNames = new HashSet<>(); if (Objects.isNull(queryData.getDateInfo())) { - return; + return removeFieldNames; } if (queryData.getDateInfo().getUnit() > 1) { queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); @@ -369,6 +382,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { } } parseInfo.setDateInfo(queryData.getDateInfo()); + return removeFieldNames; } private void addTimeFilters(String date, @@ -381,42 +395,41 @@ public class ChatQueryServiceImpl implements ChatQueryService { addConditions.add(comparisonExpression); } - private void updateFilters(List fieldExpressionList, - Set metricFilters, - Set contextMetricFilters, - List addConditions, - Set removeFieldNames) { - if (org.apache.commons.collections.CollectionUtils.isEmpty(metricFilters)) { - return; + private Set updateFilters(List fieldExpressionList, + Set metricFilters, + Set contextMetricFilters, + List addConditions) { + Set removeFieldNames = new HashSet<>(); + if (CollectionUtils.isEmpty(metricFilters)) { + return removeFieldNames; } + for (QueryFilter dslQueryFilter : metricFilters) { for (FieldExpression fieldExpression : fieldExpressionList) { if (fieldExpression.getFieldName() != null && fieldExpression.getFieldName().contains(dslQueryFilter.getName())) { removeFieldNames.add(dslQueryFilter.getName()); - if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) { - EqualsTo equalsTo = new EqualsTo(); - addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) { - GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); - addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) { - GreaterThan greaterThan = new GreaterThan(); - addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) { - MinorThanEquals minorThanEquals = new MinorThanEquals(); - addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) { - MinorThan minorThan = new MinorThan(); - addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions); - } else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) { - InExpression inExpression = new InExpression(); - addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions); - } + handleFilter(dslQueryFilter, contextMetricFilters, addConditions); break; } } } + return removeFieldNames; + } + + private void handleFilter(QueryFilter dslQueryFilter, + Set contextMetricFilters, + List addConditions) { + FilterOperatorEnum operator = dslQueryFilter.getOperator(); + + if (operator == FilterOperatorEnum.IN) { + addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters, addConditions); + } else { + ComparisonOperator expression = FilterOperatorEnum.createExpression(operator); + if (Objects.nonNull(expression)) { + addWhereFilters(dslQueryFilter, expression, contextMetricFilters, addConditions); + } + } } // add in condition to sql where condition @@ -428,7 +441,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); List valueList = JsonUtil.toList( JsonUtil.toString(dslQueryFilter.getValue()), String.class); - if (org.apache.commons.collections.CollectionUtils.isEmpty(valueList)) { + if (CollectionUtils.isEmpty(valueList)) { return; } valueList.stream().forEach(o -> { @@ -447,10 +460,10 @@ public class ChatQueryServiceImpl implements ChatQueryService { } // add where filter - private void addWhereFilters(QueryFilter dslQueryFilter, - T comparisonExpression, - Set contextMetricFilters, - List addConditions) { + private void addWhereFilters(QueryFilter dslQueryFilter, + ComparisonOperator comparisonExpression, + Set contextMetricFilters, + List addConditions) { String columnName = dslQueryFilter.getName(); if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")"; @@ -476,8 +489,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { }); } - private SemanticParseInfo mergeSemanticParseInfo(SemanticParseInfo parseInfo, - ChatQueryDataReq queryData) { + private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo, + ChatQueryDataReq queryData) { if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { return parseInfo; } @@ -500,13 +513,18 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void validFilter(Set filters) { - for (QueryFilter queryFilter : filters) { - if (Objects.isNull(queryFilter.getValue())) { - filters.remove(queryFilter); + Iterator iterator = filters.iterator(); + while (iterator.hasNext()) { + QueryFilter queryFilter = iterator.next(); + Object queryFilterValue = queryFilter.getValue(); + if (Objects.isNull(queryFilterValue)) { + iterator.remove(); + continue; } - if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty( - JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) { - filters.remove(queryFilter); + List collection = JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class); + if (FilterOperatorEnum.IN.equals(queryFilter.getOperator()) + && CollectionUtils.isEmpty(collection)) { + iterator.remove(); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldlValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java similarity index 96% rename from common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldlValueReplaceVisitor.java rename to common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java index 3a4d57821..11f070e8f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldlValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java @@ -25,13 +25,13 @@ import java.util.Map; import java.util.Objects; @Slf4j -public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { +public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); private boolean exactReplace; private Map> filedNameToValueMap; - public FieldlValueReplaceVisitor(boolean exactReplace, Map> filedNameToValueMap) { + public FieldValueReplaceVisitor(boolean exactReplace, Map> filedNameToValueMap) { this.exactReplace = exactReplace; this.filedNameToValueMap = filedNameToValueMap; } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 2653ea50d..ac2b607ea 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -2,15 +2,6 @@ package com.tencent.supersonic.common.jsqlparser; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.util.StringUtil; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.function.UnaryOperator; - import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Alias; @@ -30,6 +21,7 @@ import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.statement.select.FromItem; import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.Join; import net.sf.jsqlparser.statement.select.OrderByElement; @@ -40,11 +32,18 @@ import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SetOperationList; -import net.sf.jsqlparser.statement.select.FromItem; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.UnaryOperator; + /** * Sql Parser replace Helper */ @@ -132,7 +131,7 @@ public class SqlReplaceHelper { List plainSelects = SqlSelectHelper.getPlainSelect(selectStatement); for (PlainSelect plainSelect : plainSelects) { Expression where = plainSelect.getWhere(); - FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap); + FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(exactReplace, filedNameToValueMap); if (Objects.nonNull(where)) { where.accept(visitor); } @@ -546,7 +545,7 @@ public class SqlReplaceHelper { } PlainSelect plainSelect = (PlainSelect) selectStatement; Expression having = plainSelect.getHaving(); - FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap); + FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(false, filedNameToValueMap); if (Objects.nonNull(having)) { having.accept(visitor); } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java index 01b720603..86b499667 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java @@ -3,6 +3,12 @@ package com.tencent.supersonic.common.pojo.enums; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; public enum FilterOperatorEnum { IN("IN"), @@ -47,4 +53,20 @@ public enum FilterOperatorEnum { || MINOR_THAN_EQUALS.equals(filterOperatorEnum) || NOT_EQUALS.equals(filterOperatorEnum); } + public static ComparisonOperator createExpression(FilterOperatorEnum operator) { + switch (operator) { + case EQUALS: + return new EqualsTo(); + case GREATER_THAN_EQUALS: + return new GreaterThanEquals(); + case GREATER_THAN: + return new GreaterThan(); + case MINOR_THAN_EQUALS: + return new MinorThanEquals(); + case MINOR_THAN: + return new MinorThan(); + default: + return null; + } + } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java index 824ee0ca0..7394796d2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java @@ -202,50 +202,39 @@ public class S2SemanticLayerService implements SemanticLayerService { DimensionResp dimensionResp = getDimension(dimensionValueReq); Set dataSetIds = dimensionValueReq.getDataSetIds(); dimensionValueReq.setModelId(dimensionResp.getModelId()); + List dimensionValues = getDimensionValuesFromDict(dimensionValueReq, dataSetIds); - // if the search results is null,search dimensionValue from database + + // If the search results are null, search dimensionValue from the database if (CollectionUtils.isEmpty(dimensionValues)) { return getDimensionValuesFromDb(dimensionValueReq, user); } - List columns = new ArrayList<>(); - QueryColumn queryColumn = new QueryColumn(); - queryColumn.setNameEn(dimensionValueReq.getBizName()); - queryColumn.setShowType(SemanticType.CATEGORY.name()); - queryColumn.setAuthorized(true); - queryColumn.setType("CHAR"); - columns.add(queryColumn); - List> resultList = new ArrayList<>(); - dimensionValues.stream().forEach(o -> { - Map map = new HashMap<>(); - map.put(dimensionValueReq.getBizName(), o); - resultList.add(map); - }); + + List columns = createQueryColumns(dimensionValueReq); + List> resultList = createResultList(dimensionValueReq, dimensionValues); + semanticQueryResp.setColumns(columns); semanticQueryResp.setResultList(resultList); return semanticQueryResp; } private List getDimensionValuesFromDict(DimensionValueReq dimensionValueReq, Set dataSetIds) { - //if value is null ,then search from NATURE_TO_VALUES if (StringUtils.isBlank(dimensionValueReq.getValue())) { return SearchService.getDimensionValue(dimensionValueReq); } + Map> modelIdToDataSetIds = new HashMap<>(); modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds)); - //search from prefixSearch - List hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(), - 2000, modelIdToDataSetIds, dataSetIds); + + List hanlpMapResultList = knowledgeBaseService.prefixSearch( + dimensionValueReq.getValue(), 2000, modelIdToDataSetIds, dataSetIds); + HanlpHelper.transLetterOriginal(hanlpMapResultList); + return hanlpMapResultList.stream() - .filter(o -> { - for (String nature : o.getNatures()) { - Long elementID = NatureHelper.getElementID(nature); - if (dimensionValueReq.getElementID().equals(elementID)) { - return true; - } - } - return false; - }) + .filter(o -> o.getNatures().stream() + .map(NatureHelper::getElementID) + .anyMatch(elementID -> dimensionValueReq.getElementID().equals(elementID))) .map(MapResult::getName) .collect(Collectors.toList()); } @@ -255,11 +244,36 @@ public class S2SemanticLayerService implements SemanticLayerService { return queryByReq(querySqlReq, user); } + private List createQueryColumns(DimensionValueReq dimensionValueReq) { + QueryColumn queryColumn = new QueryColumn(); + queryColumn.setNameEn(dimensionValueReq.getBizName()); + queryColumn.setShowType(SemanticType.CATEGORY.name()); + queryColumn.setAuthorized(true); + queryColumn.setType("CHAR"); + + List columns = new ArrayList<>(); + columns.add(queryColumn); + return columns; + } + + private List> createResultList(DimensionValueReq dimensionValueReq, + List dimensionValues) { + return dimensionValues.stream() + .map(value -> { + Map map = new HashMap<>(); + map.put(dimensionValueReq.getBizName(), value); + return map; + }) + .collect(Collectors.toList()); + } + private DimensionResp getDimension(DimensionValueReq dimensionValueReq) { - DimensionResp dimensionResp = schemaService.getDimension(dimensionValueReq.getElementID()); + Long elementID = dimensionValueReq.getElementID(); + DimensionResp dimensionResp = schemaService.getDimension(elementID); if (dimensionResp == null) { - return schemaService.getDimension(dimensionValueReq.getBizName(), - dimensionValueReq.getModelId()); + String bizName = dimensionValueReq.getBizName(); + Long modelId = dimensionValueReq.getModelId(); + return schemaService.getDimension(bizName, modelId); } return dimensionResp; }