From 719b797037ff5233b3bdcd03e3e29a70ca3a7a45 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 9 Oct 2023 18:07:14 +0800 Subject: [PATCH] (improvement)(chat) support remove InExpression and partly complete fillResponse if queryResults exist primaryEntityBizName (#181) --- .../chat/api/pojo/response/ModelInfo.java | 1 + .../execute/EntityInfoExecuteResponder.java | 42 ++++++++- .../chat/service/SemanticService.java | 93 ++++++++++++------- .../common/util/jsqlparser/JsqlConstants.java | 2 + .../jsqlparser/SqlParserUpdateHelper.java | 60 ++++++++---- .../jsqlparser/SqlParserUpdateHelperTest.java | 10 ++ 6 files changed, 153 insertions(+), 55 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ModelInfo.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ModelInfo.java index 4600ca74b..b46c7b533 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ModelInfo.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ModelInfo.java @@ -8,5 +8,6 @@ import lombok.Data; public class ModelInfo extends DataInfo implements Serializable { private List words; + private String primaryEntityName; private String primaryEntityBizName; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java index 92dab6ec5..7c3c7a2fe 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java @@ -1,19 +1,57 @@ package com.tencent.supersonic.chat.responder.execute; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import org.apache.commons.collections.CollectionUtils; public class EntityInfoExecuteResponder implements ExecuteResponder { @Override public void fillResponse(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) { SemanticService semanticService = ContextUtils.getBean(SemanticService.class); - EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, queryReq.getUser()); - queryResult.setEntityInfo(entityInfo); + User user = queryReq.getUser(); + queryResult.setEntityInfo(semanticService.getEntityInfo(semanticParseInfo, user)); + + EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo.getModelId()); + if (Objects.isNull(entityInfo) || Objects.isNull(entityInfo.getModelInfo()) + || Objects.isNull(entityInfo.getModelInfo().getPrimaryEntityName())) { + return; + } + String primaryEntityBizName = entityInfo.getModelInfo().getPrimaryEntityBizName(); + boolean existPrimaryEntityName = queryResult.getQueryColumns().stream() + .anyMatch(queryColumn -> primaryEntityBizName.equals(queryColumn.getNameEn())); + + if (!existPrimaryEntityName) { + return; + } + List> queryResults = queryResult.getQueryResults(); + List entities = queryResults.stream() + .map(entry -> entry.get(primaryEntityBizName)) + .filter(Objects::nonNull) + .map(String::valueOf) + .collect(Collectors.toList()); + if (CollectionUtils.isEmpty(entities)) { + return; + } + QueryResultWithSchemaResp queryResultWithSchemaResp = semanticService.getQueryResultWithSchemaResp(entityInfo, + semanticParseInfo.getModelId(), entities, user); + if (Objects.isNull(queryResultWithSchemaResp)) { + return; + } + List> entityResultList = queryResultWithSchemaResp.getResultList(); + if (CollectionUtils.isEmpty(entityResultList)) { + return; + } } } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java index 7eb1b0e00..09bca998f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/SemanticService.java @@ -26,9 +26,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp; import com.tencent.supersonic.chat.api.pojo.response.DataInfo; -import com.tencent.supersonic.chat.api.pojo.response.ModelInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.MetricInfo; +import com.tencent.supersonic.chat.api.pojo.response.ModelInfo; import com.tencent.supersonic.chat.config.AggregatorConfig; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; @@ -51,6 +51,7 @@ import java.time.YearMonth; import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.LinkedHashSet; @@ -62,6 +63,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -113,15 +115,12 @@ public class SemanticService { } } } - if (!"".equals(modelInfoId)) { - try { - setMainModel(entityInfo, parseInfo.getModelId(), - modelInfoId, user); - - return entityInfo; - } catch (Exception e) { - log.error("setMainModel error {}", e); - } + try { + setMainModel(entityInfo, parseInfo.getModelId(), + modelInfoId, user); + return entityInfo; + } catch (Exception e) { + log.error("setMainModel error {}", e); } } } @@ -152,6 +151,7 @@ public class SemanticService { modelInfo.setWords(modelSchema.getModel().getAlias()); modelInfo.setBizName(modelSchema.getModel().getBizName()); if (Objects.nonNull(modelSchema.getEntity())) { + modelInfo.setPrimaryEntityName(modelSchema.getEntity().getName()); modelInfo.setPrimaryEntityBizName(modelSchema.getEntity().getBizName()); } @@ -190,9 +190,40 @@ public class SemanticService { } public void setMainModel(EntityInfo modelInfo, Long model, String entity, User user) { - ModelSchema modelSchema = schemaService.getModelSchema(model); + if (StringUtils.isEmpty(entity)) { + return; + } - modelInfo.setEntityId(entity); + List entities = Collections.singletonList(entity); + + QueryResultWithSchemaResp queryResultWithColumns = getQueryResultWithSchemaResp(modelInfo, model, entities, + user); + + if (queryResultWithColumns != null) { + if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList()) + && queryResultWithColumns.getResultList().size() > 0) { + Map result = queryResultWithColumns.getResultList().get(0); + for (Map.Entry entry : result.entrySet()) { + String entryKey = getEntryKey(entry); + if (entry.getValue() == null || entryKey == null) { + continue; + } + modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName())) + .forEach(i -> i.setValue(entry.getValue().toString())); + modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName())) + .forEach(i -> i.setValue(entry.getValue().toString())); + } + } + } + } + + public QueryResultWithSchemaResp getQueryResultWithSchemaResp(EntityInfo modelInfo, Long model, + List entities, User user) { + if (CollectionUtils.isEmpty(entities)) { + return null; + } + ModelSchema modelSchema = schemaService.getModelSchema(model); + modelInfo.setEntityId(entities.get(0)); SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); semanticParseInfo.setModel(modelSchema.getModel()); semanticParseInfo.setNativeQuery(true); @@ -217,10 +248,7 @@ public class SemanticService { semanticParseInfo.setDateInfo(dateInfo); // add filter - QueryFilter chatFilter = new QueryFilter(); - chatFilter.setValue(String.valueOf(entity)); - chatFilter.setOperator(FilterOperatorEnum.EQUALS); - chatFilter.setBizName(getEntityPrimaryName(modelInfo)); + QueryFilter chatFilter = getQueryFilter(modelInfo, entities); Set chatFilters = new LinkedHashSet(); chatFilters.add(chatFilter); semanticParseInfo.setDimensionFilters(chatFilters); @@ -232,23 +260,20 @@ public class SemanticService { } catch (Exception e) { log.warn("setMainModel queryByStruct error, e:", e); } + return queryResultWithColumns; + } - if (queryResultWithColumns != null) { - if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList()) - && queryResultWithColumns.getResultList().size() > 0) { - Map result = queryResultWithColumns.getResultList().get(0); - for (Map.Entry entry : result.entrySet()) { - String entryKey = getEntryKey(entry); - if (entry.getValue() == null || entryKey == null) { - continue; - } - modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName())) - .forEach(i -> i.setValue(entry.getValue().toString())); - modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName())) - .forEach(i -> i.setValue(entry.getValue().toString())); - } - } + private QueryFilter getQueryFilter(EntityInfo modelInfo, List entities) { + QueryFilter chatFilter = new QueryFilter(); + if (entities.size() == 1) { + chatFilter.setValue(entities.get(0)); + chatFilter.setOperator(FilterOperatorEnum.EQUALS); + } else { + chatFilter.setValue(entities); + chatFilter.setOperator(FilterOperatorEnum.IN); } + chatFilter.setBizName(getEntityPrimaryName(modelInfo)); + return chatFilter; } private Set getDimensions(EntityInfo modelInfo) { @@ -332,7 +357,7 @@ public class SemanticService { } public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, - QueryResultWithSchemaResp result) { + QueryResultWithSchemaResp result) { if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) { return new AggregateInfo(); } @@ -384,7 +409,7 @@ public class SemanticService { } private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric, - AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) { + AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) { MetricInfo metricInfo = new MetricInfo(); metricInfo.setStatistics(new HashMap<>()); QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum); @@ -432,7 +457,7 @@ public class SemanticService { } private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo, - QueryResultWithSchemaResp results) { + QueryResultWithSchemaResp results) { String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); Optional lastDayOp = results.getResultList().stream() .map(r -> r.get(dateField).toString()) diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java index 2d20478c0..273e46524 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/JsqlConstants.java @@ -15,4 +15,6 @@ public class JsqlConstants { public static final String GREATER_THAN_EQUALS_CONSTANT = " 1 >= 1 "; public static final String EQUAL_CONSTANT = " 1 = 1 "; + public static final String IN_CONSTANT = " 1 in (1) "; + } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index d8d294bfc..94a474abe 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -16,6 +16,7 @@ import net.sf.jsqlparser.expression.operators.conditional.XorExpression; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.expression.operators.relational.InExpression; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; @@ -526,29 +527,50 @@ public class SqlParserUpdateHelper { } private static void removeExpressionWithConstant(Expression expression, Set removeFieldNames) { - if (!(expression instanceof EqualsTo)) { - return; + if (expression instanceof EqualsTo) { + ComparisonOperator comparisonOperator = (ComparisonOperator) expression; + String columnName = getColumnName(comparisonOperator.getLeftExpression(), + comparisonOperator.getRightExpression()); + if (!removeFieldNames.contains(columnName)) { + return; + } + try { + ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression( + JsqlConstants.EQUAL_CONSTANT); + comparisonOperator.setLeftExpression(constantExpression.getLeftExpression()); + comparisonOperator.setRightExpression(constantExpression.getRightExpression()); + comparisonOperator.setASTNode(constantExpression.getASTNode()); + } catch (JSQLParserException e) { + log.error("JSQLParserException", e); + } } - ComparisonOperator comparisonOperator = (ComparisonOperator) expression; + if (expression instanceof InExpression) { + InExpression inExpression = (InExpression) expression; + String columnName = getColumnName(inExpression.getLeftExpression(), inExpression.getRightExpression()); + if (!removeFieldNames.contains(columnName)) { + return; + } + try { + InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression( + JsqlConstants.IN_CONSTANT); + inExpression.setLeftExpression(constantExpression.getLeftExpression()); + inExpression.setRightItemsList(constantExpression.getRightItemsList()); + inExpression.setASTNode(constantExpression.getASTNode()); + } catch (JSQLParserException e) { + log.error("JSQLParserException", e); + } + } + } + + private static String getColumnName(Expression leftExpression, Expression rightExpression) { String columnName = ""; - if (comparisonOperator.getRightExpression() instanceof Column) { - columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName(); + if (leftExpression instanceof Column) { + columnName = ((Column) leftExpression).getColumnName(); } - if (comparisonOperator.getLeftExpression() instanceof Column) { - columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName(); - } - if (!removeFieldNames.contains(columnName)) { - return; - } - try { - ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression( - JsqlConstants.EQUAL_CONSTANT); - comparisonOperator.setLeftExpression(constantExpression.getLeftExpression()); - comparisonOperator.setRightExpression(constantExpression.getRightExpression()); - comparisonOperator.setASTNode(constantExpression.getASTNode()); - } catch (JSQLParserException e) { - log.error("JSQLParserException", e); + if (rightExpression instanceof Column) { + columnName = ((Column) rightExpression).getColumnName(); } + return columnName; } } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index 930e1fa11..d1f9dcbcd 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -469,6 +469,16 @@ class SqlParserUpdateHelperTest { + "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames); + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND " + + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", + replaceSql); } private Map initParams() {