From 2f812372d7e979627b52699f35c0e2fc686039d0 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 2 Nov 2023 20:58:45 +0800 Subject: [PATCH] add switch to translate S2QL into SQL (#314) --- .../chat/config/OptimizationConfig.java | 4 +- .../chat/mapper/EmbeddingMapper.java | 8 +- .../chat/mapper/FuzzyNameMapper.java | 3 +- .../llm/interpret/MetricInterpretQuery.java | 18 ++- .../chat/query/rule/RuleSemanticQuery.java | 11 +- .../execute/EntityInfoExecuteResponder.java | 1 + .../chat/service/SemanticService.java | 14 +- .../chat/service/impl/QueryServiceImpl.java | 4 + .../chat/utils/DictQueryHelper.java | 34 +++-- .../chat/utils/QueryReqBuilder.java | 119 --------------- .../chat/utils/QueryReqBuilderTest.java | 7 +- .../semantic/LocalSemanticInterpreter.java | 11 +- .../semantic/RemoteSemanticInterpreter.java | 7 + .../main/resources/optimization.properties | 1 + .../api/query/request/QueryStructReq.java | 138 ++++++++++++++++++ 15 files changed, 223 insertions(+), 157 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java index 0f5e10adc..5715979a7 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java @@ -8,7 +8,6 @@ import org.springframework.context.annotation.PropertySource; @Configuration @Data @PropertySource("classpath:optimization.properties") -//@ComponentScan(basePackages = "com.tencent.supersonic.chat") public class OptimizationConfig { @Value("${one.detection.size}") @@ -40,4 +39,7 @@ public class OptimizationConfig { @Value("${candidate.threshold}") private Double candidateThreshold; + @Value("${user.s2ql.switch:false}") + private boolean useS2qlSwitch; + } diff --git 
a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java index 8db054d4b..28866424f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMapper.java @@ -4,13 +4,19 @@ import com.tencent.supersonic.chat.api.pojo.QueryContext; import lombok.extern.slf4j.Slf4j; /*** - * a mapper that is capable of semantic understanding of text. + * A mapper that is capable of semantic understanding of text. */ @Slf4j public class EmbeddingMapper extends BaseMapper { @Override public void work(QueryContext queryContext) { + //1. query from embedding by queryText + + + //2. build SchemaElementMatch by info + + //3. add to mapInfo } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java index bfb22f135..eb0548993 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/FuzzyNameMapper.java @@ -171,11 +171,10 @@ public class FuzzyNameMapper extends BaseMapper { if (CollectionUtils.isEmpty(elements)) { return new HashSet<>(); } - Set regElementSet = elements.stream() + return elements.stream() .filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType())) .map(elementMatch -> elementMatch.getElement().getId()) .collect(Collectors.toSet()); - return regElementSet; } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java index 831344496..58306bdd3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java +++ 
b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/interpret/MetricInterpretQuery.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; @@ -20,6 +21,11 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.lang3.StringUtils; @@ -27,12 +33,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import java.util.Map; -import java.util.HashMap; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - @Slf4j @Component public class MetricInterpretQuery extends PluginSemanticQuery { @@ -55,6 +55,10 @@ public class MetricInterpretQuery extends PluginSemanticQuery { fillAggregator(queryStructReq, parseInfo.getMetrics()); queryStructReq.setNativeQuery(true); SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); + + OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); + queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); + 
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user); String text = generateTableText(queryResultWithSchemaResp); Map properties = parseInfo.getProperties(); @@ -76,7 +80,7 @@ public class MetricInterpretQuery extends PluginSemanticQuery { } private String replaceText(String text, List schemaElementMatches, - Map replacedMap) { + Map replacedMap) { if (CollectionUtils.isEmpty(schemaElementMatches)) { return text; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java index f88f81a27..06d0a62bd 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java @@ -14,16 +14,17 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import 
com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; @@ -195,11 +196,17 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { } QueryResult queryResult = new QueryResult(); - QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(convertQueryStruct(), user); + QueryStructReq queryStructReq = convertQueryStruct(); + + OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); + queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); + + QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user); if (queryResp != null) { queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); } + String sql = queryResp == null ? null : queryResp.getSql(); List> resultList = queryResp == null ? new ArrayList<>() : queryResp.getResultList(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java index c5d856a6c..cd2616c0f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java @@ -51,6 +51,7 @@ public class EntityInfoExecuteResponder implements ExecuteResponder { .filter(Objects::nonNull) .map(String::valueOf) .collect(Collectors.toList()); + if (CollectionUtils.isEmpty(entities)) { return; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java index 40147bfae..581b1f993 100644 --- 
a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java @@ -30,18 +30,19 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.MetricInfo; import com.tencent.supersonic.chat.api.pojo.response.ModelInfo; import com.tencent.supersonic.chat.config.AggregatorConfig; +import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.RatioOverType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.text.DecimalFormat; import java.time.DayOfWeek; @@ -262,8 +263,10 @@ public class SemanticService { QueryResultWithSchemaResp queryResultWithColumns = null; try { - queryResultWithColumns = semanticInterpreter.queryByStruct( - QueryReqBuilder.buildStructReq(semanticParseInfo), user); + QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(semanticParseInfo); + OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); + queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); + queryResultWithColumns = 
semanticInterpreter.queryByStruct(queryStructReq, user); } catch (Exception e) { log.warn("setMainModel queryByStruct error, e:", e); } @@ -425,7 +428,12 @@ public class SemanticService { queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField))); queryStructReq.setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, results)); + + OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); + queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); + QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user); + if (Objects.nonNull(queryResp) && !CollectionUtils.isEmpty(queryResp.getResultList())) { Map result = queryResp.getResultList().get(0); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 6790e8e2c..d74f91550 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -20,6 +20,7 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; @@ -659,6 +660,9 @@ public class QueryServiceImpl implements QueryService { groups.add(dimensionValueReq.getBizName()); queryStructReq.setGroups(groups); SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); + + OptimizationConfig optimizationConfig = 
ContextUtils.getBean(OptimizationConfig.class); + queryStructReq.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); QueryResultWithSchemaResp queryResultWithSchemaResp = semanticInterpreter.queryByStruct(queryStructReq, user); return queryResultWithSchemaResp; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java index b43149ab8..a6744e5a8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/DictQueryHelper.java @@ -1,20 +1,27 @@ package com.tencent.supersonic.chat.utils; +import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER; +import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE; +import static com.tencent.supersonic.common.pojo.Constants.COMMA; +import static com.tencent.supersonic.common.pojo.Constants.SPACE; +import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE_DOUBLE; + import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.config.DefaultMetric; import com.tencent.supersonic.chat.config.Dim4Dict; -import com.tencent.supersonic.common.pojo.QueryColumn; -import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.pojo.Filter; -import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.common.pojo.Aggregator; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; +import 
com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.Order; - +import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -22,19 +29,12 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.StringJoiner; - import lombok.extern.slf4j.Slf4j; import org.apache.logging.log4j.util.Strings; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import static com.tencent.supersonic.common.pojo.Constants.SPACE; -import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER; -import static com.tencent.supersonic.common.pojo.Constants.COMMA; -import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE; -import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE_DOUBLE; - @Slf4j @Component public class DictQueryHelper { @@ -55,7 +55,11 @@ public class DictQueryHelper { List data = new ArrayList<>(); QueryStructReq queryStructCmd = generateQueryStructCmd(modelId, defaultMetricDesc, dim4Dict); try { + OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); + queryStructCmd.setUseS2qlSwitch(optimizationConfig.isUseS2qlSwitch()); + QueryResultWithSchemaResp queryResultWithColumns = semanticInterpreter.queryByStruct(queryStructCmd, user); + log.info("fetchDimValueSingle sql:{}", queryResultWithColumns.getSql()); String nature = String.format("_%d_%d", modelId, dim4Dict.getDimId()); String dimNameRewrite = 
rewriteDimName(queryResultWithColumns.getColumns(), dim4Dict.getBizName()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java index 5c7c62696..3358782b1 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java @@ -11,10 +11,6 @@ import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.DateModeUtils; -import com.tencent.supersonic.common.util.SqlFilterUtils; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; @@ -22,7 +18,6 @@ import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.time.LocalDate; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; @@ -30,22 +25,6 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.JSQLParserException; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.Function; -import net.sf.jsqlparser.expression.LongValue; -import net.sf.jsqlparser.expression.operators.relational.ExpressionList; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; -import net.sf.jsqlparser.schema.Column; -import net.sf.jsqlparser.schema.Table; -import 
net.sf.jsqlparser.statement.select.GroupByElement; -import net.sf.jsqlparser.statement.select.Limit; -import net.sf.jsqlparser.statement.select.OrderByElement; -import net.sf.jsqlparser.statement.select.PlainSelect; -import net.sf.jsqlparser.statement.select.Select; -import net.sf.jsqlparser.statement.select.SelectExpressionItem; -import net.sf.jsqlparser.statement.select.SelectItem; -import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.beans.BeanUtils; import org.springframework.util.CollectionUtils; @@ -159,104 +138,6 @@ public class QueryReqBuilder { return queryS2QLReq; } - /** - * convert queryStructReq to QueryS2QLReq - * - * @param queryStructReq - * @return - */ - public static QueryS2QLReq buildS2QLReq(QueryStructReq queryStructReq) throws JSQLParserException { - Select select = new Select(); - //1.Set the select items (columns) - PlainSelect plainSelect = new PlainSelect(); - List selectItems = new ArrayList<>(); - List groups = queryStructReq.getGroups(); - if (!CollectionUtils.isEmpty(groups)) { - for (String group : groups) { - selectItems.add(new SelectExpressionItem(new Column(group))); - } - } - List aggregators = queryStructReq.getAggregators(); - if (!CollectionUtils.isEmpty(aggregators)) { - for (Aggregator aggregator : aggregators) { - if (queryStructReq.getNativeQuery()) { - selectItems.add(new SelectExpressionItem(new Column(aggregator.getColumn()))); - } else { - Function sumFunction = new Function(); - AggOperatorEnum func = aggregator.getFunc(); - if (AggOperatorEnum.UNKNOWN.equals(func)) { - func = AggOperatorEnum.SUM; - } - sumFunction.setName(func.getOperator()); - sumFunction.setParameters(new ExpressionList(new Column(aggregator.getColumn()))); - selectItems.add(new SelectExpressionItem(sumFunction)); - } - } - } - plainSelect.setSelectItems(selectItems); - //2.Set the table name - Table table = new Table(Constants.TABLE_PREFIX + queryStructReq.getModelId()); - 
plainSelect.setFromItem(table); - - //3.Set the order by clause - List orders = queryStructReq.getOrders(); - if (!CollectionUtils.isEmpty(orders)) { - List orderByElements = new ArrayList<>(); - for (Order order : orders) { - OrderByElement orderByElement = new OrderByElement(); - orderByElement.setExpression(new Column(order.getColumn())); - orderByElement.setAsc(false); - if (Constants.ASC_UPPER.equalsIgnoreCase(order.getDirection())) { - orderByElement.setAsc(true); - } - orderByElements.add(orderByElement); - } - plainSelect.setOrderByElements(orderByElements); - } - - //4.Set the group by clause - if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getNativeQuery()) { - GroupByElement groupByElement = new GroupByElement(); - for (String group : groups) { - groupByElement.addGroupByExpression(new Column(group)); - } - plainSelect.setGroupByElement(groupByElement); - } - - //7.Set the limit clause - if (Objects.nonNull(queryStructReq.getLimit())) { - Limit limit = new Limit(); - limit.setRowCount(new LongValue(queryStructReq.getLimit())); - plainSelect.setLimit(limit); - } - select.setSelectBody(plainSelect); - - //5.Set where - List dimensionFilters = queryStructReq.getDimensionFilters(); - SqlFilterUtils sqlFilterUtils = ContextUtils.getBean(SqlFilterUtils.class); - String whereClause = sqlFilterUtils.getWhereClause(dimensionFilters); - - String sql = select.toString(); - if (StringUtils.isNotBlank(whereClause)) { - Expression expression = CCJSqlParserUtil.parseCondExpression(whereClause); - sql = SqlParserAddHelper.addWhere(sql, expression); - } - - //6.Set DateInfo - DateModeUtils dateModeUtils = ContextUtils.getBean(DateModeUtils.class); - String dateWhereStr = dateModeUtils.getDateWhereStr(queryStructReq.getDateInfo()); - if (StringUtils.isNotBlank(dateWhereStr)) { - Expression expression = CCJSqlParserUtil.parseCondExpression(dateWhereStr); - sql = SqlParserAddHelper.addWhere(sql, expression); - } - - QueryS2QLReq result = new QueryS2QLReq(); - 
result.setSql(sql); - result.setModelId(queryStructReq.getModelId()); - result.setVariables(new HashMap<>()); - return result; - } - private static List getAggregatorByMetric(AggregateTypeEnum aggregateType, SchemaElement metric) { List aggregators = new ArrayList<>(); if (metric != null) { diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/utils/QueryReqBuilderTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/utils/QueryReqBuilderTest.java index 8a957605d..f59d65689 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/utils/QueryReqBuilderTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/utils/QueryReqBuilderTest.java @@ -14,7 +14,6 @@ import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import net.sf.jsqlparser.JSQLParserException; import org.junit.Assert; import org.junit.jupiter.api.Test; import org.mockito.MockedStatic; @@ -26,7 +25,7 @@ import org.mockito.Mockito; class QueryReqBuilderTest { @Test - void buildS2QLReq() throws JSQLParserException { + void buildS2QLReq() { init(); QueryStructReq queryStructReq = new QueryStructReq(); queryStructReq.setModelId(1L); @@ -50,13 +49,13 @@ class QueryReqBuilderTest { orders.add(order); queryStructReq.setOrders(orders); - QueryS2QLReq queryS2QLReq = QueryReqBuilder.buildS2QLReq(queryStructReq); + QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq); Assert.assertEquals( "SELECT department, SUM(pv) FROM t_1 WHERE (sys_imp_date IN ('2023-08-01')) " + "GROUP BY department ORDER BY uv LIMIT 2000", queryS2QLReq.getSql()); queryStructReq.setNativeQuery(true); - queryS2QLReq = QueryReqBuilder.buildS2QLReq(queryStructReq); + queryS2QLReq = queryStructReq.convert(queryStructReq); Assert.assertEquals( "SELECT department, pv FROM t_1 WHERE (sys_imp_date IN ('2023-08-01')) " + "ORDER BY uv LIMIT 2000", diff --git 
a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticInterpreter.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticInterpreter.java index fabbc8f39..c78871a7f 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticInterpreter.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticInterpreter.java @@ -27,6 +27,7 @@ import com.tencent.supersonic.semantic.query.service.SchemaService; import java.util.List; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; @Slf4j public class LocalSemanticInterpreter extends BaseSemanticInterpreter { @@ -39,6 +40,12 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter { @SneakyThrows @Override public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) { + QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq); + if (queryStructReq.isUseS2qlSwitch() && StringUtils.isNotBlank(queryS2QLReq.getSql())) { + log.info("queryStructReq convert to sql:{},queryStructReq:{}", queryS2QLReq.getSql(), queryStructReq); + return queryByS2QL(queryS2QLReq, user); + } + queryService = ContextUtils.getBean(QueryService.class); return queryService.queryByStructWithAuth(queryStructReq, user); } @@ -59,9 +66,7 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter { public QueryResultWithSchemaResp queryByS2QL(QueryS2QLReq queryS2QLReq, User user) { queryService = ContextUtils.getBean(QueryService.class); Object object = queryService.queryBySql(queryS2QLReq, user); - QueryResultWithSchemaResp queryResultWithSchemaResp = JsonUtil.toObject(JsonUtil.toString(object), - QueryResultWithSchemaResp.class); - return queryResultWithSchemaResp; + return JsonUtil.toObject(JsonUtil.toString(object), QueryResultWithSchemaResp.class); } @Override diff --git 
a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticInterpreter.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticInterpreter.java index 2fdbb4f86..e29007f66 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticInterpreter.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticInterpreter.java @@ -40,6 +40,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Objects; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.beans.BeanUtils; import org.springframework.core.ParameterizedTypeReference; @@ -68,6 +69,12 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter { @Override public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) { + QueryS2QLReq queryS2QLReq = queryStructReq.convert(queryStructReq); + if (queryStructReq.isUseS2qlSwitch() && StringUtils.isNotBlank(queryS2QLReq.getSql())) { + log.info("queryStructReq convert to sql:{},queryStructReq:{}", queryS2QLReq.getSql(), queryStructReq); + return queryByS2QL(queryS2QLReq, user); + } + DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class); return searchByRestTemplate( defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchByStructPath(), diff --git a/launchers/standalone/src/main/resources/optimization.properties b/launchers/standalone/src/main/resources/optimization.properties index bd4e14d4e..c0efcb683 100644 --- a/launchers/standalone/src/main/resources/optimization.properties +++ b/launchers/standalone/src/main/resources/optimization.properties @@ -8,3 +8,4 @@ long.text.threshold=0.8 short.text.threshold=0.5 query.text.length.threshold=10 candidate.threshold=0.2 +user.s2ql.switch=false \ No newline at end of file 
diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java index 8f3d79d36..331724165 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/QueryStructReq.java @@ -1,6 +1,12 @@ package com.tencent.supersonic.semantic.api.query.request; import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateModeUtils; +import com.tencent.supersonic.common.util.SqlFilterUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.semantic.api.query.pojo.Cache; import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.semantic.api.query.pojo.Param; @@ -9,16 +15,36 @@ import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.Order; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.statement.select.GroupByElement; +import net.sf.jsqlparser.statement.select.Limit; +import net.sf.jsqlparser.statement.select.OrderByElement; +import 
net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; +import net.sf.jsqlparser.statement.select.SelectItem; import org.apache.commons.codec.digest.DigestUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.util.CollectionUtils; @Data +@Slf4j public class QueryStructReq { private Long modelId; @@ -34,6 +60,8 @@ public class QueryStructReq { private Boolean nativeQuery = false; private Cache cacheInfo; + private boolean useS2qlSwitch; + public List getGroups() { if (!CollectionUtils.isEmpty(this.groups)) { this.groups = groups.stream().filter(group -> !Strings.isEmpty(group)).collect(Collectors.toList()); @@ -129,4 +157,114 @@ public class QueryStructReq { sb.append('}'); return sb.toString(); } + + + /** + * Convert a QueryStructReq to a QueryS2QLReq by rendering it as SQL via buildSql; if SQL generation throws, the error is logged and the sql of the result stays null + * + * @param queryStructReq the structured query to translate (the argument is read, not this instance) + * @return a new QueryS2QLReq carrying the generated sql (possibly null), the modelId of the argument and an empty variables map + */ + public QueryS2QLReq convert(QueryStructReq queryStructReq) { + String sql = null; + try { + sql = buildSql(queryStructReq); + } catch (Exception e) { + log.error("buildSql error", e); + } + + QueryS2QLReq result = new QueryS2QLReq(); + result.setSql(sql); + result.setModelId(queryStructReq.getModelId()); + result.setVariables(new HashMap<>()); + return result; + } + + private String buildSql(QueryStructReq queryStructReq) throws JSQLParserException { + Select select = new Select(); + //1.Set the select items (columns): group columns first, then one item per aggregator (the bare column for native queries, otherwise an aggregate function where UNKNOWN falls back to SUM) + PlainSelect plainSelect = new PlainSelect(); + List selectItems = new ArrayList<>(); + List groups = queryStructReq.getGroups(); + if (!CollectionUtils.isEmpty(groups)) { + for (String group : groups) { + selectItems.add(new SelectExpressionItem(new Column(group))); + } + } + List aggregators = queryStructReq.getAggregators(); + if (!CollectionUtils.isEmpty(aggregators)) { + for (Aggregator aggregator : aggregators) { + if (queryStructReq.getNativeQuery()) { + selectItems.add(new
SelectExpressionItem(new Column(aggregator.getColumn()))); + } else { + Function sumFunction = new Function(); + AggOperatorEnum func = aggregator.getFunc(); + if (AggOperatorEnum.UNKNOWN.equals(func)) { + func = AggOperatorEnum.SUM; + } + sumFunction.setName(func.getOperator()); + sumFunction.setParameters(new ExpressionList(new Column(aggregator.getColumn()))); + selectItems.add(new SelectExpressionItem(sumFunction)); + } + } + } + plainSelect.setSelectItems(selectItems); + //2.Set the table name: Constants.TABLE_PREFIX followed by the model id + Table table = new Table(Constants.TABLE_PREFIX + queryStructReq.getModelId()); + plainSelect.setFromItem(table); + + //3.Set the order by clause: each order is descending unless its direction equals Constants.ASC_UPPER (case-insensitive) + List orders = queryStructReq.getOrders(); + if (!CollectionUtils.isEmpty(orders)) { + List orderByElements = new ArrayList<>(); + for (Order order : orders) { + OrderByElement orderByElement = new OrderByElement(); + orderByElement.setExpression(new Column(order.getColumn())); + orderByElement.setAsc(false); + if (Constants.ASC_UPPER.equalsIgnoreCase(order.getDirection())) { + orderByElement.setAsc(true); + } + orderByElements.add(orderByElement); + } + plainSelect.setOrderByElements(orderByElements); + } + + //4.Set the group by clause (skipped for native queries). NOTE(review): nativeQuery is a boxed Boolean, so a null value would throw on unboxing here and in step 1; confirm it is never set to null + if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getNativeQuery()) { + GroupByElement groupByElement = new GroupByElement(); + for (String group : groups) { + groupByElement.addGroupByExpression(new Column(group)); + } + plainSelect.setGroupByElement(groupByElement); + } + + //5.Set the limit clause (only when the request carries a non-null limit) + if (Objects.nonNull(queryStructReq.getLimit())) { + Limit limit = new Limit(); + limit.setRowCount(new LongValue(queryStructReq.getLimit())); + plainSelect.setLimit(limit); + } + select.setSelectBody(plainSelect); + + //6.Set where: the dimension filters are rendered to a string by SqlFilterUtils and, when non-blank, parsed and passed to SqlParserAddHelper.addWhere together with the generated sql + List dimensionFilters = queryStructReq.getDimensionFilters(); + SqlFilterUtils sqlFilterUtils = ContextUtils.getBean(SqlFilterUtils.class); + String whereClause = sqlFilterUtils.getWhereClause(dimensionFilters); + + String sql = select.toString(); + if
(StringUtils.isNotBlank(whereClause)) { + Expression expression = CCJSqlParserUtil.parseCondExpression(whereClause); + sql = SqlParserAddHelper.addWhere(sql, expression); + } + + //7.Set DateInfo: the date condition string from DateModeUtils is, when non-blank, parsed and passed to the same addWhere helper + DateModeUtils dateModeUtils = ContextUtils.getBean(DateModeUtils.class); + String dateWhereStr = dateModeUtils.getDateWhereStr(queryStructReq.getDateInfo()); + if (StringUtils.isNotBlank(dateWhereStr)) { + Expression expression = CCJSqlParserUtil.parseCondExpression(dateWhereStr); + sql = SqlParserAddHelper.addWhere(sql, expression); + } + return sql; + } + }