diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java index 05321e919..0ffe6b7ba 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.common.pojo.enums; import cn.hutool.core.collection.CollectionUtil; + import java.util.Arrays; import java.util.List; import java.util.Map; @@ -58,7 +59,7 @@ public enum TimeDimensionEnum { } /** - * Determine if a time dimension field is included in a Chinese text field + * Determine if a time dimension field is included in a Chinese/English text field * * @param fields field * @return true/false @@ -67,8 +68,6 @@ public enum TimeDimensionEnum { if (CollectionUtil.isEmpty(fields)) { return false; } - return fields.contains(TimeDimensionEnum.DAY.getChName()) - || fields.contains(TimeDimensionEnum.WEEK.getChName()) - || fields.contains(TimeDimensionEnum.MONTH.getChName()); + return fields.stream().anyMatch(field -> containsTimeDimension(field)); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java index 290605ac0..c8c503910 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java @@ -2,10 +2,15 @@ package com.tencent.supersonic.headless.server.rest.api; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; +import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector; +import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.server.service.QueryService; +import com.tencent.supersonic.headless.server.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -32,6 +37,7 @@ public class SqlQueryApiController { HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); + correct(querySqlReq); return queryService.queryByReq(querySqlReq, user); } @@ -45,8 +51,25 @@ public class SqlQueryApiController { QuerySqlReq querySqlReq = new QuerySqlReq(); BeanUtils.copyProperties(querySqlsReq, querySqlReq); querySqlReq.setSql(sql); + correct(querySqlReq); return querySqlReq; }).collect(Collectors.toList()); return queryService.queryByReqs(semanticQueryReqs, user); } + + private void correct(QuerySqlReq querySqlReq) { + QueryContext queryCtx = new QueryContext(); + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + SqlInfo sqlInfo = new SqlInfo(); + sqlInfo.setCorrectS2SQL(querySqlReq.getSql()); + sqlInfo.setS2SQL(querySqlReq.getSql()); + semanticParseInfo.setSqlInfo(sqlInfo); + + ComponentFactory.getSemanticCorrectors().forEach(corrector -> { + if (!(corrector instanceof GrammarCorrector)) { + corrector.correct(queryCtx, semanticParseInfo); + } + }); + querySqlReq.setSql(sqlInfo.getCorrectS2SQL()); + } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 292815d66..9bf9d9809 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -180,6 +180,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { .modelIdToDataSetIds(modelIdToDataSetIds) .text2SQLType(queryReq.getText2SQLType()) .mapModeEnum(queryReq.getMapModeEnum()) + .dataSetIds(queryReq.getDataSetIds()) .build(); BeanUtils.copyProperties(queryReq, queryCtx); return queryCtx; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index 26783fb84..b14b35c74 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -116,6 +116,9 @@ public class DataSetServiceImpl if (metaFilter.getName() != null) { wrapper.lambda().eq(DataSetDO::getName, metaFilter.getName()); } + if (!CollectionUtils.isEmpty(metaFilter.getNames())) { + wrapper.lambda().in(DataSetDO::getName, metaFilter.getNames()); + } wrapper.lambda().ne(DataSetDO::getStatus, StatusEnum.DELETED.getCode()); return list(wrapper).stream().map(entry -> convert(entry, user)).collect(Collectors.toList()); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java index a146d077d..b6ce4fb75 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java @@ -157,13 +157,13 @@ public class QueryReqConverter { return AggOption.DEFAULT; } - private void convertNameToBizName(QuerySqlReq databaseReq, SemanticSchemaResp semanticSchemaResp) { + private void convertNameToBizName(QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) { Map fieldNameToBizNameMap = getFieldNameToBizNameMap(semanticSchemaResp); - String sql = databaseReq.getSql(); - log.info("convert name to bizName before:{}", sql); + String sql = querySqlReq.getSql(); + log.info("dataSetId:{},convert name to bizName before:{}", querySqlReq.getDataSetId(), sql); String replaceFields = SqlReplaceHelper.replaceFields(sql, fieldNameToBizNameMap, true); - log.info("convert name to bizName after:{}", replaceFields); - databaseReq.setSql(replaceFields); + log.info("dataSetId:{},convert name to bizName after:{}", querySqlReq.getDataSetId(), replaceFields); + querySqlReq.setSql(replaceFields); } private Set getDimensions(SemanticSchemaResp semanticSchemaResp, List allFields) {